MMGeneration | PyTorch 零基础入门 GAN 模型

简介: GAN 的全称是 Generative Adversarial Networks,翻译过来就是生成对抗网络。这里我们需要重点理解的是对抗的含义:GAN 的基本想法是两个网络,生成器(G)和判别器(D),在训练过程中相互对抗。

640.png

快乐搬砖的周一又如期而至

度过了幸福洋溢的周末(狗头)

想必大家一定干劲充沛!

640.png

不知七夕大家成功检测到对象没?

没有检测到也别着急

小编这不是带着 MMGeneration 来了吗!

640.png


那么怎么生成对象呢? 对象是什么,对象怎么生成?今天小编就来带领大家学习一下到底怎么生成对象,关于生成对象的事,想必大家都有所了解,什么一见钟情啦、寤寐思服啦,都二三十岁的人了还不知道可遇不可求,做白日梦想桃子吃呐? 以上就是生成对象的含义和方法,希望小编精心整理的这篇内容能够解决你的困惑。


640.png


要想学习用 MMGeneration 生成对象,首先要了解如何使用和训练各种生成模型


近年来,各种生成模型及其应用广泛地出现在大家的视野范围内,像最近非常火爆的 Alias-Free GAN 更是从一个全新的视角,为生成模型领域中新的发展方向打下了坚实的理论基础。但是现在来看,无论是之前的 StyleGAN2还是现在的 Alias-Free GAN,模型细节还有训练过程都是非常繁杂的。


640.png


同时,如果再结合 PyTorch 的话,又需要考虑各种分布式训练的问题。在这种内卷和快速迭代的时代,如何快速上手,把握住机会呢?


在这一系列的教程里面呢,我们将以 MMGeneration 为基础,来帮助大家快速入门 GAN 这个庞大的领域。选择 MMGeneration 是因为,在前期,你可以不需要任何 PyTorch 的基础,到后期熟练之后,你也只需要在对应模型上进行一些修改就可以轻松地上手做一些实验啦。这么好用的工具包,还不来 Star 一波?


本文内容

1. 从 GAN 到 DCGAN

2. 模型结构和代码分析

3. DCGAN CelebA 实验


1. 入门第一步:从 GAN 到 DCGAN



GAN 的全称是 Generative Adversarial Networks,翻译过来就是生成对抗网络。这里我们需要重点理解的是对抗的含义:GAN 的基本想法是两个网络,生成器(G)和判别器(D),在训练过程中相互对抗。看起来就像是两人练武,虽然一开始大家都很菜,但是判别器进步一点,然后生成器就迎头赶上,然后慢慢地两个人携手成为一代宗师。


这是一种很抽象的理解,具体到数学理论上,GAN 的思想是由 Ian J. Goodfellow 在 Generative Adversarial Nets 中完整提出并证明的。

640.png

整个 GAN 的对抗思想就体现在上图所示的损失函数公式当中,下面我们一点一点解析这个公式。


简析损失函数公式


在公式中 G(z) 描述了生成器的工作方式:输入一个噪声信号,然后输出一个尽可能逼真的图片(样本)。


D(x) 和 D(G(z)) 表示判别器输入的分别是真实数据集中的图片和生成的图片,判别器的任务就是需要尽可能将两者区分开(在这里可以看作一个简单的二分类问题),这也就是让 D 去最大化这个损失函数的意义。


而 G 就需要尽可能让 D 无法分辨出真实图片和虚假图片的差异,换句话说期望 G 生成的图片可以以假乱真。

640.png

上图展示具体的训练流程,可以看到,在训练过程中 G 和 D 的优化是交替进行的,而且一般情况下我们往往希望 D 学习的稍微快一点,这样能够带动 G 更好地朝着全局最优解方向优化。具体的理论推导,大家可以去参考Generative Adversarial Networks 一文,在这里通过 KL 散度可以很非常直观地看到我们的优化目标最优解就是生成器能够完全拟合真实样本的分布。

在 Ian J. Goodfellow 的文章中,主要是提出了 GAN 的思想。可是面对图像,我们常用的算子是卷积层。在 DCGAN 一文中,作者将 transposed convolution 用到了生成器中,成为了一个里程碑式的工作。从此之后,以卷积神经网络为主体的 GAN 模型就不断地涌现了出来。那下面我们将从这个基础模型入手,来看一下 GAN 网络的构造。


2. 模型结构和代码分析



640.png

上图就展示了一个典型的生成器和判别器的结构。在生成器当中,首先我们需要一个将噪声向量转换成为二维特征的模块,也就是 noise2feat block。


那接下来需要连续经过几个上采样块将低分辨率的特征转换成高分辨的特征,在 DCGAN 中,我们使用的是 transposed convolution 来实现。


最后,需要一个 to_rgb 块来将特征图的通道数映射为3通道,从而生成图片。那判别器其实就是生成器的一个反转,我们需要通过 img2feat 和大量的下采样块将特征图不断降低分辨率,最后输送给 decision head,来对当前的输入图片进行评判。

对应到 MMGeneration 当中,模型的具体代码都存放在 models/architectures/dcgan 文件夹下面,下图展示了对应的生成器和判别器的代码逻辑,在 mmgen 中我们也是严格按照这样的设计来构建代码,相信大家能够更容易上手。


如果关心具体实现的同学,可以到文件中查看,如果你暂时还是 PyTorch 初学者,那你大可不必关心具体的实现,我们接下来告诉大家怎么用 mmgen 训练一个 DCGAN。

640.png


3. DCGAN CelebA 实验



接下来,将通过 MMGeneration 来一步一步详细地教大家如何训练第一个 DCGAN 模型。这里不需要大家有太多的 PyTorch 基础知识,只需要跟着我们一步一步来就可以了。


Step 1 安装


使用 MMGeneration,你只需要克隆一下 github 上面的仓库到本地,然后按照安装手册配置一下环境即可,如果安装遇到什么问题,可以给 MMGeneration 提 issue,或是加入 OpenMMLab 社区微信群提问(入群方式见文末),我们会尽快为小伙伴们解答。下面是具体的安装步骤:

# we assume that you have installed pytorch and mmcv-full in your env.
# clone repo
git clone https://github.com/open-mmlab/mmgeneration mmgen
cd mmgen
# install mmgen
pip install -e .


Step 2 数据


假设大家已经安装好了 MMGeneration,回到训练上来,首先我们要做的是准备训练数据,CelebA 的数据可以通过其官方网站下载,我们选用其中的 Align&Cropped 数据来进行训练。下载解压完了之后,我们需要回到 MMGeneration 仓库的文件夹,通过软链的方式将数据链接到仓库的 data 目录下面:

mkdir data
ln -s absolute_path_to_CelebA ./data/celeba

这样我们的数据准备工作就基本完成了。不过我们需要再更新一下我们的 config 文件的中的 img_root 字段,将我们现在的数据路径更新上去。具体要做的就是修改 dcgan-celeba config 文件的第 11 行

# define dataset
# you must set `samples_per_gpu` and `imgs_root`
data = dict(
    samples_per_gpu=128,
    train=dict(imgs_root='data/celeba'))  # set img_root

这个 config 文件其实就能帮我们定义整个训练的过程,包括数据集的构造,模型的定义以及训练流程的定义等等,详细的介绍后续会带给大家。大家现在可以先通过我们提供的 config 来实现快速的上手训练采样生成图片


Step 3 训练


训练的指令其实非常简单,通过我们之前修改的 config 文件,我们就可以通过如下命令进行训练了:

bash tools/dist_train.sh ./configs/dcgan/dcgan_celeba-cropped_64_b128x1_300k.py 1 --work-dir ./work_dirs/dcgan-celeba

在训练过程中,我们会自动保存不同阶段模型生成的样本到 work_dirs/dcgan-celeba 文件夹下面:

640.gif

这样我们就可以随时观测到模型的收敛情况了,当然后续的教程里面,我们还会介绍如何通过一些客观的评价指标来检测我们的训练过程。


训练完成之后,我们就可以通过随机采样来看看模型能带给我们什么样的样本啦。在 MMGeneration 当中,我们可以轻松得通过使用demo/unconditional_demo.py 来实现:

python demo/unconditional_demo.py ./configs/dcgan/dcgan_celeba-cropped_64_b128x1_300k.py work_dirs/dcgan-celeba/ckpt/iter_290000.pth


其实在 MMGeneration 当中已经支持了非常多模型的采样,并且提供了公共的 checkpoint 供大家把玩,在我们的快速上手教程中有更详细地介绍(可联系小运营获取快速上手教程),欢迎大家来试用并且提出你们宝贵的意见。


总之,活用 MMGeneration

不仅仅能生成对象哦!

心动不如行动,快一起试一下吧!

640.png


下 *N 期预告

# MMGeneration # 真 · 生成对象


MMGeneration 开发者说

如果 Github 上 OpenMMLab 的 Star

增加超过 1314 就安排

兄弟们,机会来了靠自己了!


文章来源:公众号【OpenMMLab】

2021-08-16 19:00


目录
相关文章
|
2月前
|
机器学习/深度学习 自然语言处理 PyTorch
【PyTorch实战演练】基于AlexNet的预训练模型介绍
【PyTorch实战演练】基于AlexNet的预训练模型介绍
85 0
|
3月前
|
机器学习/深度学习 并行计算 PyTorch
TensorRT部署系列 | 如何将模型从 PyTorch 转换为 TensorRT 并加速推理?
TensorRT部署系列 | 如何将模型从 PyTorch 转换为 TensorRT 并加速推理?
162 0
|
1月前
|
机器学习/深度学习 关系型数据库 MySQL
大模型中常用的注意力机制GQA详解以及Pytorch代码实现
GQA是一种结合MQA和MHA优点的注意力机制,旨在保持MQA的速度并提供MHA的精度。它将查询头分成组,每组共享键和值。通过Pytorch和einops库,可以简洁实现这一概念。GQA在保持高效性的同时接近MHA的性能,是高负载系统优化的有力工具。相关论文和非官方Pytorch实现可进一步探究。
106 4
|
2月前
|
机器学习/深度学习 数据采集 PyTorch
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
|
15天前
|
PyTorch 算法框架/工具 Python
【pytorch框架】对模型知识的基本了解
【pytorch框架】对模型知识的基本了解
|
25天前
|
机器学习/深度学习 算法 PyTorch
PyTorch模型优化与调优:正则化、批归一化等技巧
【4月更文挑战第18天】本文探讨了PyTorch中提升模型性能的优化技巧,包括正则化(L1/L2正则化、Dropout)、批归一化、学习率调整策略和模型架构优化。正则化防止过拟合,Dropout提高泛化能力;批归一化加速训练并提升性能;学习率调整策略动态优化训练效果;模型架构优化涉及网络结构和参数的调整。这些方法有助于实现更高效的深度学习模型。
|
25天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch与迁移学习:利用预训练模型提升性能
【4月更文挑战第18天】PyTorch支持迁移学习,助力提升深度学习性能。预训练模型(如ResNet、VGG)在大规模数据集(如ImageNet)训练后,可在新任务中加速训练,提高准确率。通过选择模型、加载预训练权重、修改结构和微调,可适应不同任务需求。迁移学习节省资源,但也需考虑源任务与目标任务的相似度及超参数选择。实践案例显示,预训练模型能有效提升小数据集上的图像分类任务性能。未来,迁移学习将继续在深度学习领域发挥重要作用。
|
2月前
|
PyTorch 算法框架/工具 Python
Pytorch构建网络模型时super(__class__, self).__init__()的作用
Pytorch构建网络模型时super(__class__, self).__init__()的作用
12 0
|
2月前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch用GAN生成手写数字实例(附代码)
基于Pytorch用GAN生成手写数字实例(附代码)
37 0
|
2月前
|
PyTorch 算法框架/工具 Python
基于Pytorch的YoLoV4模型代码及作品欣赏
基于Pytorch的YoLoV4模型代码及作品欣赏
30 0