MMGeneration | PyTorch 零基础入门 GAN 模型

简介: GAN 的全称是 Generative Adversarial Networks,翻译过来就是生成对抗网络。这里我们需要重点理解的是对抗的含义:GAN 的基本想法是两个网络,生成器(G)和判别器(D),在训练过程中相互对抗。

640.png

快乐搬砖的周一又如期而至

度过了幸福洋溢的周末(狗头)

想必大家一定干劲充沛!

640.png

不知七夕大家成功检测到对象没?

没有检测到也别着急

小编这不是带着 MMGeneration 来了吗!

640.png


那么怎么生成对象呢? 对象是什么,对象怎么生成?今天小编就来带领大家学习一下到底怎么生成对象,关于生成对象的事,想必大家都有所了解,什么一见钟情啦、寤寐思服啦,都二三十岁的人了还不知道可遇不可求,做白日梦想桃子吃呐? 以上就是生成对象的含义和方法,希望小编精心整理的这篇内容能够解决你的困惑。


640.png


要想学习用 MMGeneration 生成对象,首先要了解如何使用和训练各种生成模型


近年来,各种生成模型及其应用广泛地出现在大家的视野范围内,像最近非常火爆的 Alias-Free GAN 更是从一个全新的视角,为生成模型领域中新的发展方向打下了坚实的理论基础。但是现在来看,无论是之前的 StyleGAN2还是现在的 Alias-Free GAN,模型细节还有训练过程都是非常繁杂的。


640.png


同时,如果再结合 PyTorch 的话,又需要考虑各种分布式训练的问题。在这种内卷和快速迭代的时代,如何快速上手,把握住机会呢?


在这一系列的教程里面呢,我们将以 MMGeneration 为基础,来帮助大家快速入门 GAN 这个庞大的领域。选择 MMGeneration 是因为,在前期,你可以不需要任何 PyTorch 的基础,到后期熟练之后,你也只需要在对应模型上进行一些修改就可以轻松地上手做一些实验啦。这么好用的工具包,还不来 Star 一波?


本文内容

1. 从 GAN 到 DCGAN

2. 模型结构和代码分析

3. DCGAN CelebA 实验


1. 入门第一步:从 GAN 到 DCGAN



GAN 的全称是 Generative Adversarial Networks,翻译过来就是生成对抗网络。这里我们需要重点理解的是对抗的含义:GAN 的基本想法是两个网络,生成器(G)和判别器(D),在训练过程中相互对抗。看起来就像是两人练武,虽然一开始大家都很菜,但是判别器进步一点,然后生成器就迎头赶上,然后慢慢地两个人携手成为一代宗师。


这是一种很抽象的理解,具体到数学理论上,GAN 的思想是由 Ian J. Goodfellow 在 Generative Adversarial Nets 中完整提出并证明的。

640.png

整个 GAN 的对抗思想就体现在上图所示的损失函数公式当中,下面我们一点一点解析这个公式。


简析损失函数公式


在公式中 G(z) 描述了生成器的工作方式:输入一个噪声信号,然后输出一个尽可能逼真的图片(样本)。


D(x) 和 D(G(z)) 表示判别器输入的分别是真实数据集中的图片和生成的图片,判别器的任务就是需要尽可能将两者区分开(在这里可以看作一个简单的二分类问题),这也就是让 D 去最大化这个损失函数的意义。


而 G 就需要尽可能让 D 无法分辨出真实图片和虚假图片的差异,换句话说期望 G 生成的图片可以以假乱真。

640.png

上图展示具体的训练流程,可以看到,在训练过程中 G 和 D 的优化是交替进行的,而且一般情况下我们往往希望 D 学习的稍微快一点,这样能够带动 G 更好地朝着全局最优解方向优化。具体的理论推导,大家可以去参考Generative Adversarial Networks 一文,在这里通过 KL 散度可以很非常直观地看到我们的优化目标最优解就是生成器能够完全拟合真实样本的分布。

在 Ian J. Goodfellow 的文章中,主要是提出了 GAN 的思想。可是面对图像,我们常用的算子是卷积层。在 DCGAN 一文中,作者将 transposed convolution 用到了生成器中,成为了一个里程碑式的工作。从此之后,以卷积神经网络为主体的 GAN 模型就不断地涌现了出来。那下面我们将从这个基础模型入手,来看一下 GAN 网络的构造。


2. 模型结构和代码分析



640.png

上图就展示了一个典型的生成器和判别器的结构。在生成器当中,首先我们需要一个将噪声向量转换成为二维特征的模块,也就是 noise2feat block。


那接下来需要连续经过几个上采样块将低分辨率的特征转换成高分辨的特征,在 DCGAN 中,我们使用的是 transposed convolution 来实现。


最后,需要一个 to_rgb 块来将特征图的通道数映射为3通道,从而生成图片。那判别器其实就是生成器的一个反转,我们需要通过 img2feat 和大量的下采样块将特征图不断降低分辨率,最后输送给 decision head,来对当前的输入图片进行评判。

对应到 MMGeneration 当中,模型的具体代码都存放在 models/architectures/dcgan 文件夹下面,下图展示了对应的生成器和判别器的代码逻辑,在 mmgen 中我们也是严格按照这样的设计来构建代码,相信大家能够更容易上手。


如果关心具体实现的同学,可以到文件中查看,如果你暂时还是 PyTorch 初学者,那你大可不必关心具体的实现,我们接下来告诉大家怎么用 mmgen 训练一个 DCGAN。

640.png


3. DCGAN CelebA 实验



接下来,将通过 MMGeneration 来一步一步详细地教大家如何训练第一个 DCGAN 模型。这里不需要大家有太多的 PyTorch 基础知识,只需要跟着我们一步一步来就可以了。


Step 1 安装


使用 MMGeneration,你只需要克隆一下 github 上面的仓库到本地,然后按照安装手册配置一下环境即可,如果安装遇到什么问题,可以给 MMGeneration 提 issue,或是加入 OpenMMLab 社区微信群提问(入群方式见文末),我们会尽快为小伙伴们解答。下面是具体的安装步骤:

# we assume that you have installed pytorch and mmcv-full in your env.
# clone repo
git clone https://github.com/open-mmlab/mmgeneration mmgen
cd mmgen
# install mmgen
pip install -e .


Step 2 数据


假设大家已经安装好了 MMGeneration,回到训练上来,首先我们要做的是准备训练数据,CelebA 的数据可以通过其官方网站下载,我们选用其中的 Align&Cropped 数据来进行训练。下载解压完了之后,我们需要回到 MMGeneration 仓库的文件夹,通过软链的方式将数据链接到仓库的 data 目录下面:

mkdir data
ln -s absolute_path_to_CelebA ./data/celeba

这样我们的数据准备工作就基本完成了。不过我们需要再更新一下我们的 config 文件的中的 img_root 字段,将我们现在的数据路径更新上去。具体要做的就是修改 dcgan-celeba config 文件的第 11 行

# define dataset
# you must set `samples_per_gpu` and `imgs_root`
data = dict(
    samples_per_gpu=128,
    train=dict(imgs_root='data/celeba'))  # set img_root

这个 config 文件其实就能帮我们定义整个训练的过程,包括数据集的构造,模型的定义以及训练流程的定义等等,详细的介绍后续会带给大家。大家现在可以先通过我们提供的 config 来实现快速的上手训练采样生成图片


Step 3 训练


训练的指令其实非常简单,通过我们之前修改的 config 文件,我们就可以通过如下命令进行训练了:

bash tools/dist_train.sh ./configs/dcgan/dcgan_celeba-cropped_64_b128x1_300k.py 1 --work-dir ./work_dirs/dcgan-celeba

在训练过程中,我们会自动保存不同阶段模型生成的样本到 work_dirs/dcgan-celeba 文件夹下面:

640.gif

这样我们就可以随时观测到模型的收敛情况了,当然后续的教程里面,我们还会介绍如何通过一些客观的评价指标来检测我们的训练过程。


训练完成之后,我们就可以通过随机采样来看看模型能带给我们什么样的样本啦。在 MMGeneration 当中,我们可以轻松得通过使用demo/unconditional_demo.py 来实现:

python demo/unconditional_demo.py ./configs/dcgan/dcgan_celeba-cropped_64_b128x1_300k.py work_dirs/dcgan-celeba/ckpt/iter_290000.pth


其实在 MMGeneration 当中已经支持了非常多模型的采样,并且提供了公共的 checkpoint 供大家把玩,在我们的快速上手教程中有更详细地介绍(可联系小运营获取快速上手教程),欢迎大家来试用并且提出你们宝贵的意见。


总之,活用 MMGeneration

不仅仅能生成对象哦!

心动不如行动,快一起试一下吧!

640.png


下 *N 期预告

# MMGeneration # 真 · 生成对象


MMGeneration 开发者说

如果 Github 上 OpenMMLab 的 Star

增加超过 1314 就安排

兄弟们,机会来了靠自己了!


文章来源:公众号【OpenMMLab】

2021-08-16 19:00


目录
相关文章
|
7月前
|
机器学习/深度学习 PyTorch API
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
本文深入探讨神经网络模型量化技术,重点讲解训练后量化(PTQ)与量化感知训练(QAT)两种主流方法。PTQ通过校准数据集确定量化参数,快速实现模型压缩,但精度损失较大;QAT在训练中引入伪量化操作,使模型适应低精度环境,显著提升量化后性能。文章结合PyTorch实现细节,介绍Eager模式、FX图模式及PyTorch 2导出量化等工具,并分享大语言模型Int4/Int8混合精度实践。最后总结量化最佳策略,包括逐通道量化、混合精度设置及目标硬件适配,助力高效部署深度学习模型。
1188 21
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
|
3月前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
209 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
2月前
|
边缘计算 人工智能 PyTorch
130_知识蒸馏技术:温度参数与损失函数设计 - 教师-学生模型的优化策略与PyTorch实现
随着大型语言模型(LLM)的规模不断增长,部署这些模型面临着巨大的计算和资源挑战。以DeepSeek-R1为例,其671B参数的规模即使经过INT4量化后,仍需要至少6张高端GPU才能运行,这对于大多数中小型企业和研究机构来说成本过高。知识蒸馏作为一种有效的模型压缩技术,通过将大型教师模型的知识迁移到小型学生模型中,在显著降低模型复杂度的同时保留核心性能,成为解决这一问题的关键技术之一。
|
4月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.0性能优化实战:4种常见代码错误严重拖慢模型
我们将深入探讨图中断(graph breaks)和多图问题对性能的负面影响,并分析PyTorch模型开发中应当避免的常见错误模式。
309 9
|
6月前
|
机器学习/深度学习 存储 PyTorch
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
本文通过使用 Kaggle 数据集训练情感分析模型的实例,详细演示了如何将 PyTorch 与 MLFlow 进行深度集成,实现完整的实验跟踪、模型记录和结果可复现性管理。文章将系统性地介绍训练代码的核心组件,展示指标和工件的记录方法,并提供 MLFlow UI 的详细界面截图。
294 2
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
提升模型泛化能力:PyTorch的L1、L2、ElasticNet正则化技术深度解析与代码实现
本文将深入探讨L1、L2和ElasticNet正则化技术,重点关注其在PyTorch框架中的具体实现。关于这些技术的理论基础,建议读者参考相关理论文献以获得更深入的理解。
201 4
提升模型泛化能力:PyTorch的L1、L2、ElasticNet正则化技术深度解析与代码实现
|
5月前
|
机器学习/深度学习 数据可视化 PyTorch
Flow Matching生成模型:从理论基础到Pytorch代码实现
本文将系统阐述Flow Matching的完整实现过程,包括数学理论推导、模型架构设计、训练流程构建以及速度场学习等关键组件。通过本文的学习,读者将掌握Flow Matching的核心原理,获得一个完整的PyTorch实现,并对生成模型在噪声调度和分数函数之外的发展方向有更深入的理解。
2327 0
Flow Matching生成模型:从理论基础到Pytorch代码实现
|
7月前
|
机器学习/深度学习 PyTorch 编译器
深入解析torch.compile:提升PyTorch模型性能、高效解决常见问题
PyTorch 2.0推出的`torch.compile`功能为深度学习模型带来了显著的性能优化能力。本文从实用角度出发,详细介绍了`torch.compile`的核心技巧与应用场景,涵盖模型复杂度评估、可编译组件分析、系统化调试策略及性能优化高级技巧等内容。通过解决图断裂、重编译频繁等问题,并结合分布式训练和NCCL通信优化,开发者可以有效提升日常开发效率与模型性能。文章为PyTorch用户提供了全面的指导,助力充分挖掘`torch.compile`的潜力。
852 17
|
8月前
|
存储 自然语言处理 PyTorch
从零开始用Pytorch实现LLaMA 4的混合专家(MoE)模型
近期发布的LLaMA 4模型引入混合专家(MoE)架构,以提升效率与性能。尽管社区对其实际表现存在讨论,但MoE作为重要设计范式再次受到关注。本文通过Pytorch从零实现简化版LLaMA 4 MoE模型,涵盖数据准备、分词、模型构建(含词元嵌入、RoPE、RMSNorm、多头注意力及MoE层)到训练与文本生成全流程。关键点包括MoE层实现(路由器、专家与共享专家)、RoPE处理位置信息及RMSNorm归一化。虽规模小于实际LLaMA 4,但清晰展示MoE核心机制:动态路由与稀疏激活专家,在控制计算成本的同时提升性能。完整代码见链接,基于FareedKhan-dev的Github代码修改而成。
345 9
从零开始用Pytorch实现LLaMA 4的混合专家(MoE)模型
|
7月前
|
机器学习/深度学习 搜索推荐 PyTorch
基于昇腾用PyTorch实现CTR模型DIN(Deep interest Netwok)网络
本文详细讲解了如何在昇腾平台上使用PyTorch训练推荐系统中的经典模型DIN(Deep Interest Network)。主要内容包括:DIN网络的创新点与架构剖析、Activation Unit和Attention模块的实现、Amazon-book数据集的介绍与预处理、模型训练过程定义及性能评估。通过实战演示,利用Amazon-book数据集训练DIN模型,最终评估其点击率预测性能。文中还提供了代码示例,帮助读者更好地理解每个步骤的实现细节。

热门文章

最新文章

推荐镜像

更多