MMGeneration | PyTorch 零基础入门 GAN 模型

简介: GAN 的全称是 Generative Adversarial Networks,翻译过来就是生成对抗网络。这里我们需要重点理解的是对抗的含义:GAN 的基本想法是两个网络,生成器(G)和判别器(D),在训练过程中相互对抗。

640.png

快乐搬砖的周一又如期而至

度过了幸福洋溢的周末(狗头)

想必大家一定干劲充沛!

640.png

不知七夕大家成功检测到对象没?

没有检测到也别着急

小编这不是带着 MMGeneration 来了吗!

640.png


那么怎么生成对象呢? 对象是什么,对象怎么生成?今天小编就来带领大家学习一下到底怎么生成对象,关于生成对象的事,想必大家都有所了解,什么一见钟情啦、寤寐思服啦,都二三十岁的人了还不知道可遇不可求,做白日梦想桃子吃呐? 以上就是生成对象的含义和方法,希望小编精心整理的这篇内容能够解决你的困惑。


640.png


要想学习用 MMGeneration 生成对象,首先要了解如何使用和训练各种生成模型


近年来,各种生成模型及其应用广泛地出现在大家的视野范围内,像最近非常火爆的 Alias-Free GAN 更是从一个全新的视角,为生成模型领域中新的发展方向打下了坚实的理论基础。但是现在来看,无论是之前的 StyleGAN2还是现在的 Alias-Free GAN,模型细节还有训练过程都是非常繁杂的。


640.png


同时,如果再结合 PyTorch 的话,又需要考虑各种分布式训练的问题。在这种内卷和快速迭代的时代,如何快速上手,把握住机会呢?


在这一系列的教程里面呢,我们将以 MMGeneration 为基础,来帮助大家快速入门 GAN 这个庞大的领域。选择 MMGeneration 是因为,在前期,你可以不需要任何 PyTorch 的基础,到后期熟练之后,你也只需要在对应模型上进行一些修改就可以轻松地上手做一些实验啦。这么好用的工具包,还不来 Star 一波?


本文内容

1. 从 GAN 到 DCGAN

2. 模型结构和代码分析

3. DCGAN CelebA 实验


1. 入门第一步:从 GAN 到 DCGAN



GAN 的全称是 Generative Adversarial Networks,翻译过来就是生成对抗网络。这里我们需要重点理解的是对抗的含义:GAN 的基本想法是两个网络,生成器(G)和判别器(D),在训练过程中相互对抗。看起来就像是两人练武,虽然一开始大家都很菜,但是判别器进步一点,然后生成器就迎头赶上,然后慢慢地两个人携手成为一代宗师。


这是一种很抽象的理解,具体到数学理论上,GAN 的思想是由 Ian J. Goodfellow 在 Generative Adversarial Nets 中完整提出并证明的。

640.png

整个 GAN 的对抗思想就体现在上图所示的损失函数公式当中,下面我们一点一点解析这个公式。


简析损失函数公式


在公式中 G(z) 描述了生成器的工作方式:输入一个噪声信号,然后输出一个尽可能逼真的图片(样本)。


D(x) 和 D(G(z)) 表示判别器输入的分别是真实数据集中的图片和生成的图片,判别器的任务就是需要尽可能将两者区分开(在这里可以看作一个简单的二分类问题),这也就是让 D 去最大化这个损失函数的意义。


而 G 就需要尽可能让 D 无法分辨出真实图片和虚假图片的差异,换句话说期望 G 生成的图片可以以假乱真。

640.png

上图展示具体的训练流程,可以看到,在训练过程中 G 和 D 的优化是交替进行的,而且一般情况下我们往往希望 D 学习的稍微快一点,这样能够带动 G 更好地朝着全局最优解方向优化。具体的理论推导,大家可以去参考Generative Adversarial Networks 一文,在这里通过 KL 散度可以很非常直观地看到我们的优化目标最优解就是生成器能够完全拟合真实样本的分布。

在 Ian J. Goodfellow 的文章中,主要是提出了 GAN 的思想。可是面对图像,我们常用的算子是卷积层。在 DCGAN 一文中,作者将 transposed convolution 用到了生成器中,成为了一个里程碑式的工作。从此之后,以卷积神经网络为主体的 GAN 模型就不断地涌现了出来。那下面我们将从这个基础模型入手,来看一下 GAN 网络的构造。


2. 模型结构和代码分析



640.png

上图就展示了一个典型的生成器和判别器的结构。在生成器当中,首先我们需要一个将噪声向量转换成为二维特征的模块,也就是 noise2feat block。


那接下来需要连续经过几个上采样块将低分辨率的特征转换成高分辨的特征,在 DCGAN 中,我们使用的是 transposed convolution 来实现。


最后,需要一个 to_rgb 块来将特征图的通道数映射为3通道,从而生成图片。那判别器其实就是生成器的一个反转,我们需要通过 img2feat 和大量的下采样块将特征图不断降低分辨率,最后输送给 decision head,来对当前的输入图片进行评判。

对应到 MMGeneration 当中,模型的具体代码都存放在 models/architectures/dcgan 文件夹下面,下图展示了对应的生成器和判别器的代码逻辑,在 mmgen 中我们也是严格按照这样的设计来构建代码,相信大家能够更容易上手。


如果关心具体实现的同学,可以到文件中查看,如果你暂时还是 PyTorch 初学者,那你大可不必关心具体的实现,我们接下来告诉大家怎么用 mmgen 训练一个 DCGAN。

640.png


3. DCGAN CelebA 实验



接下来,将通过 MMGeneration 来一步一步详细地教大家如何训练第一个 DCGAN 模型。这里不需要大家有太多的 PyTorch 基础知识,只需要跟着我们一步一步来就可以了。


Step 1 安装


使用 MMGeneration,你只需要克隆一下 github 上面的仓库到本地,然后按照安装手册配置一下环境即可,如果安装遇到什么问题,可以给 MMGeneration 提 issue,或是加入 OpenMMLab 社区微信群提问(入群方式见文末),我们会尽快为小伙伴们解答。下面是具体的安装步骤:

# we assume that you have installed pytorch and mmcv-full in your env.
# clone repo
git clone https://github.com/open-mmlab/mmgeneration mmgen
cd mmgen
# install mmgen
pip install -e .


Step 2 数据


假设大家已经安装好了 MMGeneration,回到训练上来,首先我们要做的是准备训练数据,CelebA 的数据可以通过其官方网站下载,我们选用其中的 Align&Cropped 数据来进行训练。下载解压完了之后,我们需要回到 MMGeneration 仓库的文件夹,通过软链的方式将数据链接到仓库的 data 目录下面:

mkdir data
ln -s absolute_path_to_CelebA ./data/celeba

这样我们的数据准备工作就基本完成了。不过我们需要再更新一下我们的 config 文件的中的 img_root 字段,将我们现在的数据路径更新上去。具体要做的就是修改 dcgan-celeba config 文件的第 11 行

# define dataset
# you must set `samples_per_gpu` and `imgs_root`
data = dict(
    samples_per_gpu=128,
    train=dict(imgs_root='data/celeba'))  # set img_root

这个 config 文件其实就能帮我们定义整个训练的过程,包括数据集的构造,模型的定义以及训练流程的定义等等,详细的介绍后续会带给大家。大家现在可以先通过我们提供的 config 来实现快速的上手训练采样生成图片


Step 3 训练


训练的指令其实非常简单,通过我们之前修改的 config 文件,我们就可以通过如下命令进行训练了:

bash tools/dist_train.sh ./configs/dcgan/dcgan_celeba-cropped_64_b128x1_300k.py 1 --work-dir ./work_dirs/dcgan-celeba

在训练过程中,我们会自动保存不同阶段模型生成的样本到 work_dirs/dcgan-celeba 文件夹下面:

640.gif

这样我们就可以随时观测到模型的收敛情况了,当然后续的教程里面,我们还会介绍如何通过一些客观的评价指标来检测我们的训练过程。


训练完成之后,我们就可以通过随机采样来看看模型能带给我们什么样的样本啦。在 MMGeneration 当中,我们可以轻松得通过使用demo/unconditional_demo.py 来实现:

python demo/unconditional_demo.py ./configs/dcgan/dcgan_celeba-cropped_64_b128x1_300k.py work_dirs/dcgan-celeba/ckpt/iter_290000.pth


其实在 MMGeneration 当中已经支持了非常多模型的采样,并且提供了公共的 checkpoint 供大家把玩,在我们的快速上手教程中有更详细地介绍(可联系小运营获取快速上手教程),欢迎大家来试用并且提出你们宝贵的意见。


总之,活用 MMGeneration

不仅仅能生成对象哦!

心动不如行动,快一起试一下吧!

640.png


下 *N 期预告

# MMGeneration # 真 · 生成对象


MMGeneration 开发者说

如果 Github 上 OpenMMLab 的 Star

增加超过 1314 就安排

兄弟们,机会来了靠自己了!


文章来源:公众号【OpenMMLab】

2021-08-16 19:00


目录
相关文章
|
29天前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
138 2
|
1月前
|
机器学习/深度学习 自然语言处理 监控
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
52 8
利用 PyTorch Lightning 搭建一个文本分类模型
|
1月前
|
机器学习/深度学习 自然语言处理 数据建模
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
本文深入探讨了Transformer模型中的三种关键注意力机制:自注意力、交叉注意力和因果自注意力,这些机制是GPT-4、Llama等大型语言模型的核心。文章不仅讲解了理论概念,还通过Python和PyTorch从零开始实现这些机制,帮助读者深入理解其内部工作原理。自注意力机制通过整合上下文信息增强了输入嵌入,多头注意力则通过多个并行的注意力头捕捉不同类型的依赖关系。交叉注意力则允许模型在两个不同输入序列间传递信息,适用于机器翻译和图像描述等任务。因果自注意力确保模型在生成文本时仅考虑先前的上下文,适用于解码器风格的模型。通过本文的详细解析和代码实现,读者可以全面掌握这些机制的应用潜力。
48 3
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
|
2月前
|
机器学习/深度学习 PyTorch 调度
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
在深度学习中,学习率作为关键超参数对模型收敛速度和性能至关重要。传统方法采用统一学习率,但研究表明为不同层设置差异化学习率能显著提升性能。本文探讨了这一策略的理论基础及PyTorch实现方法,包括模型定义、参数分组、优化器配置及训练流程。通过示例展示了如何为ResNet18设置不同层的学习率,并介绍了渐进式解冻和层适应学习率等高级技巧,帮助研究者更好地优化模型训练。
129 4
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
|
2月前
|
机器学习/深度学习 监控 PyTorch
PyTorch 模型调试与故障排除指南
在深度学习领域,PyTorch 成为开发和训练神经网络的主要框架之一。本文为 PyTorch 开发者提供全面的调试指南,涵盖从基础概念到高级技术的内容。目标读者包括初学者、中级开发者和高级工程师。本文探讨常见问题及解决方案,帮助读者理解 PyTorch 的核心概念、掌握调试策略、识别性能瓶颈,并通过实际案例获得实践经验。无论是在构建简单神经网络还是复杂模型,本文都将提供宝贵的洞察和实用技巧,帮助开发者更高效地开发和优化 PyTorch 模型。
40 3
PyTorch 模型调试与故障排除指南
|
1月前
|
存储 并行计算 PyTorch
探索PyTorch:模型的定义和保存方法
探索PyTorch:模型的定义和保存方法
|
3月前
|
机器学习/深度学习 PyTorch 编译器
PyTorch 与 TorchScript:模型的序列化与加速
【8月更文第27天】PyTorch 是一个非常流行的深度学习框架,它以其灵活性和易用性而著称。然而,当涉及到模型的部署和性能优化时,PyTorch 的动态计算图可能会带来一些挑战。为了解决这些问题,PyTorch 引入了 TorchScript,这是一个用于序列化和优化 PyTorch 模型的工具。本文将详细介绍如何使用 TorchScript 来序列化 PyTorch 模型以及如何加速模型的执行。
112 4
|
3月前
|
机器学习/深度学习 边缘计算 PyTorch
PyTorch 与边缘计算:将深度学习模型部署到嵌入式设备
【8月更文第29天】随着物联网技术的发展,越来越多的数据处理任务开始在边缘设备上执行,以减少网络延迟、降低带宽成本并提高隐私保护水平。PyTorch 是一个广泛使用的深度学习框架,它不仅支持高效的模型训练,还提供了多种工具帮助开发者将模型部署到边缘设备。本文将探讨如何将PyTorch模型高效地部署到嵌入式设备上,并通过一个具体的示例来展示整个流程。
460 1
|
3月前
|
机器学习/深度学习 自然语言处理 PyTorch
PyTorch与Hugging Face Transformers:快速构建先进的NLP模型
【8月更文第27天】随着自然语言处理(NLP)技术的快速发展,深度学习模型已经成为了构建高质量NLP应用程序的关键。PyTorch 作为一种强大的深度学习框架,提供了灵活的 API 和高效的性能,非常适合于构建复杂的 NLP 模型。Hugging Face Transformers 库则是目前最流行的预训练模型库之一,它为 PyTorch 提供了大量的预训练模型和工具,极大地简化了模型训练和部署的过程。
160 2
|
3月前
|
机器学习/深度学习 边缘计算 PyTorch
PyTorch 与 ONNX:模型的跨平台部署策略
【8月更文第27天】深度学习模型的训练通常是在具有强大计算能力的平台上完成的,比如配备有高性能 GPU 的服务器。然而,为了将这些模型应用到实际产品中,往往需要将其部署到各种不同的设备上,包括移动设备、边缘计算设备甚至是嵌入式系统。这就需要一种能够在多种平台上运行的模型格式。ONNX(Open Neural Network Exchange)作为一种开放的标准,旨在解决模型的可移植性问题,使得开发者可以在不同的框架之间无缝迁移模型。本文将介绍如何使用 PyTorch 将训练好的模型导出为 ONNX 格式,并进一步探讨如何在不同平台上部署这些模型。
188 2
下一篇
无影云桌面