ViTGAN:用视觉Transformer训练生成性对抗网络 Training GANs with Vision Transformers

简介: ViTGAN是加州大学圣迭戈分校与 Google Research提出的一种用视觉Transformer来训练GAN的模型。该论文已被NIPS(Conference and Workshop on Neural Information Processing Systems,计算机人工智能领域A类会议)录用,文章发表于2021年10月。论文地址:https://arxiv.org/abs/2107.04589代码地址:https://github.com/teodorToshkov/ViTGAN-pytorch本博客是精读这篇论文的报告,包含一些个人理解、知识拓展和总结。

> @[TOC](目录)


ViTGAN是加州大学圣迭戈分校与 Google Research提出的一种用视觉Transformer来训练GAN的模型。该论文已被NIPS(Conference and Workshop on Neural Information Processing Systems,计算机人工智能领域A类会议)录用,文章发表于2021年10月。


论文地址:https://arxiv.org/abs/2107.04589

代码地址:https://github.com/teodorToshkov/ViTGAN-pytorch


本博客是精读这篇论文的报告,包含一些个人理解、知识拓展和总结。


# 一、原文摘要


> 最近,Vision Transformer(VIT)在图像识别方面表现出了竞争性的性能,同时需要更少的视觉特定感应偏差。在本文中,我们研究这种观察是否可以扩展到图像生成。为此,我们将ViT体系结构集成到生成性对抗网络(GAN)中。我们观察到,现有的GANs正则化方法与自我注意的交互作用很差,导致训练期间严重不稳定。为了解决这个问题,我们引入了新的正则化技术,用ViTs训练GANs。根据经验,我们的方法名为ViTGAN,在CIFAR-10、CelebA和LSUN卧室数据集上实现了与基于CNN的最先进StyleGAN2相当的性能


# 二、为什么提出ViTGAN?


2021年论文《An image is worth 16x16 words: Transformers for image recognition at scale. In ICLR,2021.》**视觉Transformer(简称ViT)** 在图像识别方面表现出具有竞争力的性能,在ViTGAN中,主要研究的是是否能用Transformer来进行图像生成。研究议题是:不使用卷积或池化,能否使用视觉 Transformer 来完成图像生成任务?更具体而言:能否使用 ViT 来训练生成对抗网络(GAN)并使之达到与已被广泛研究过的基于 CNN 的 GAN 相媲美的质量?


**使用原始Vit来组建GAN时,训练非常不稳定**,而且在鉴别器训练的后期,对抗性训练经常受到高方差梯度的阻碍,此外,**传统的正则化方法,如梯度惩罚,谱归一化无法解决这个不稳定性问题**。针对这些问题,为了实现训练动态的稳定以及促进基于 ViT 的 GAN 的收敛,这篇论文提出了多项必需的修改。


以往的生成性Transformer将图像生成建模为一个自回归序列学习问题。与ViTGAN比较接近的工作就是TransGAN,TransGAN提出多任务协同训练和局部初始化以获得更好的训练,但却忽略了训练稳定性的关键技术,在很大程度上落后于领先的卷积GAN模型。


# 三、Vision Transformer

Vision Transformer是一种纯Transformer架构,用于对一系列图像块进行操作的图像分类。


在ViT中,$\mathbf{x} \in \mathbb{R}^{H \times W \times C}$被展平为一系列patches,每个patch为$\mathbf{x}_{p} \in \mathbb{R}^{L \times\left(P^{2} \cdot C\right)}$,其中$L=\frac{H \times W}{P^{2}}$,P×P×C是每个图像块的尺寸。


图像序列中引入一个可学习的分类嵌入$x_{class}$,已经位置嵌入$E_{po}$,形成patch嵌入$h_0$:

$\begin{aligned}

\mathbf{h}_{0} &=\left[\mathbf{x}_{\text {class }} ; \mathbf{x}_{p}^{1} \mathbf{E} ; \mathbf{x}_{p}^{2} \mathbf{E} ; \cdots ; \mathbf{x}_{p}^{L} \mathbf{E}\right]+\mathbf{E}_{p o s}, & & \mathbf{E} \in \mathbb{R}^{\left(P^{2} \cdot C\right) \times D}, \mathbf{E}_{p o s} \in \mathbb{R}^{(L+1) \times D} \\

\mathbf{h}_{\ell}^{\prime} &=\operatorname{MSA}\left(\operatorname{LN}\left(\mathbf{h}_{\ell-1}\right)\right)+\mathbf{h}_{\ell-1}, & & \ell=1, \ldots, L \\

\mathbf{h}_{\ell} &=\operatorname{MLP}\left(\operatorname{LN}\left(\mathbf{h}_{\ell}^{\prime}\right)\right)+\mathbf{h}_{\ell}^{\prime}, & & \ell=1, \ldots, L \\

\mathbf{y} &=\operatorname{LN}\left(\mathbf{h}_{L}^{0}\right) & &

\end{aligned}$


其中,MSA是多头自注意力(MSA):

$\operatorname{MSA}(\mathbf{X})=\operatorname{concat}_{h=1}^{H}\left[\operatorname{Attention}_{h}(\mathbf{X})\right] \mathbf{W}+\mathbf{b}$


单个注意力头的计算公式为:

$\operatorname{Attention}_{h}(\mathbf{X})=\operatorname{softmax}\left(\frac{\mathbf{Q} \mathbf{K}^{\top}}{\sqrt{d_{h}}}\right) \mathbf{V}$


# 四、ViTGAN

ViTGAN的基础结构如下,一个ViT组成了生成器,一个ViT组成了鉴别器:

![在这里插入图片描述](https://ucc.alicdn.com/images/user-upload-01/18f63b3b1c834359aada2ea5d8dad792.png)

直接使用ViT会使训练不稳定,于是作者引入了(1)生成器结构优化;(2)鉴别器正则化


## 4.1、生成器

因为ViT(Vision Transformer)原来是对图片进行分类,预测标签,而**ViTGAN想达到的是让其能在空间区域生成像素**。


作者为此比较了三种Transformer做生成器的架构,输入为 由MLP从高斯噪声向量z 导出的潜在向量w

![在这里插入图片描述](https://ucc.alicdn.com/images/user-upload-01/2dc2771bccc34795a773998ebce5490a.png)


(A):在每个位置嵌入中加入中间潜在嵌入w,然后经过Transformer和一层MLP分别指导不同patch块的像素生成

(B):只在序列最开始加入中间潜在嵌入w

(C):将归一化Norm层替换为自调制层(SLN),该自调制层如下所示,其使用从w中学到的仿射变换(A)对norm层进行调整。

![在这里插入图片描述](https://ucc.alicdn.com/images/user-upload-01/e532b31f845b40499a6f016c4213c22e.png)

作者使用的是C,下面将对其结构和原理进行剖析:

### 4.1.1、生成器设计

要用Transformer生成像素值,就要使用一个线性投影层E,其将输入的D维嵌入映射到每个大小为P×P×C的patch当中,然后每个patch(一共(H*W)/P² 个patch)最终重组成一整张图像。


于是基于ViT设计的生成器由两个组件组成:(1)Transformer块;(2)输出映射层。


**如下图所示,Transformer块作为编码器,主体结构如下右所示,将Embedding经过Norm、多头注意力层、Norm和MLP后输出到输出映射层,输出映射层主要是一个MLP。**

![在这里插入图片描述](https://ucc.alicdn.com/images/user-upload-01/b12d56da74f9426387b215a0db37320d.png)

计算原理如下:

$\begin{aligned}

\mathbf{h}_{0} &=\mathbf{E}_{\text {pos }}, & & \mathbf{E}_{\text {pos }} \in \mathbb{R}^{L \times D}, \\

\mathbf{h}_{\ell}^{\prime} &=\operatorname{MSA}\left(\operatorname{SLN}\left(\mathbf{h}_{\ell-1}, \mathbf{w}\right)\right)+\mathbf{h}_{\ell-1}, & & \ell=1, \ldots, L, \mathbf{w} \in \mathbb{R}^{D} \\

\mathbf{h}_{\ell} &=\operatorname{MLP}\left(\operatorname{SLN}\left(\mathbf{h}_{\ell}^{\prime}, \mathbf{w}\right)\right)+\mathbf{h}_{\ell}^{\prime}, & & \ell=1, \ldots, L \\

\mathbf{y} &=\operatorname{SLN}\left(\mathbf{h}_{L}, \mathbf{w}\right)=\left[\mathbf{y}^{1}, \cdots, \mathbf{y}^{L}\right] & \mathbf{y}^{1}, \ldots, \mathbf{y}^{L} \in \mathbb{R}^{D} \\

\mathbf{x} &=\left[\mathbf{x}_{p}^{1}, \cdots, \mathbf{x}_{p}^{L}\right]=\left[f_{\theta}\left(\mathbf{E}_{f o u}, \mathbf{y}^{1}\right), \ldots, f_{\theta}\left(\mathbf{E}_{f o u}, \mathbf{y}^{L}\right)\right] & & \mathbf{x}_{p}^{i} \in \mathbb{R}^{P^{2} \times C}, \mathbf{x} \in \mathbb{R}^{H \times W \times C}

\end{aligned}$


### 4.1.2、 自调制层归一化层(SLN)

自调制是指:不使用噪声z作为输入,而是使用z来调制LayerNorm运算:

$\operatorname{SLN}\left(\mathbf{h}_{\ell}, \mathbf{w}\right)=\operatorname{SLN}\left(\mathbf{h}_{\ell}, \operatorname{MLP}(\mathbf{z})\right)=\gamma_{\ell}(\mathbf{w}) \odot \frac{\mathbf{h}_{\ell}-\boldsymbol{\mu}}{\boldsymbol{\sigma}}+\beta_{\ell}(\mathbf{w})$


其中µ和σ表示的是总输入的均值和方差,γl和βl表示的是计算由z导出的潜在向量控制的自适应归一化参数。


### 4.1.3、隐式神经表征生成patch片图像

使用隐式神经表示学习从patch embedding $y^i$到patch pixel$x^i_p$的映射。当与傅里叶特征或正弦激活函数结合时,隐式表示可以将生成样本的空间限制为平滑变化的自然信号的空间,在式子中表示为$E_{fou}$是空间位置的傅里叶编码,$f_θ$是两层MLP。


## 4.2、鉴别器设计

鉴别器暂略,详情可以看原文


# 五、实验

## 5.1、数据集

CIFAR-10 、LSUN bedroom、CelebA

## 5.2、实验结果

![在这里插入图片描述](https://ucc.alicdn.com/images/user-upload-01/64886508a93b4a42830643c84fcc5da5.png)

![在这里插入图片描述](https://ucc.alicdn.com/images/user-upload-01/7a6a806d27af41e3a09f484c035ba616.png)

## 5.3、消融实验

![在这里插入图片描述](https://ucc.alicdn.com/images/user-upload-01/4bca99a92db24d29a19c7c717dcf7593.png)

![在这里插入图片描述](https://ucc.alicdn.com/images/user-upload-01/12083f0b2a1e42e69617cd17109207a9.png)


# 六、总结

1. 在GANs中利用了vision transformer,并提出了确保其训练稳定性和改进其收敛性的关键技术;

2. 经过丰富的实验证明其与基于CNN的最先进的GANs性能相当。


# 最后

💖 个人简介:人工智能领域研究生,目前主攻文本生成图像(text to image)方向


📝 个人主页:[中杯可乐多加冰](https://blog.csdn.net/air__Heaven)


🔥  **限时免费**订阅:[文本生成图像T2I专栏](https://blog.csdn.net/air__heaven/category_11407863.html)



🎉 支持我:点赞👍+收藏⭐️+留言📝

相关文章
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
目标检测实战(一):CIFAR10结合神经网络加载、训练、测试完整步骤
这篇文章介绍了如何使用PyTorch框架,结合CIFAR-10数据集,通过定义神经网络、损失函数和优化器,进行模型的训练和测试。
139 2
目标检测实战(一):CIFAR10结合神经网络加载、训练、测试完整步骤
|
2月前
|
机器学习/深度学习 数据可视化 计算机视觉
目标检测笔记(五):详细介绍并实现可视化深度学习中每层特征层的网络训练情况
这篇文章详细介绍了如何通过可视化深度学习中每层特征层来理解网络的内部运作,并使用ResNet系列网络作为例子,展示了如何在训练过程中加入代码来绘制和保存特征图。
67 1
目标检测笔记(五):详细介绍并实现可视化深度学习中每层特征层的网络训练情况
|
2月前
|
机器学习/深度学习 人工智能
类人神经网络再进一步!DeepMind最新50页论文提出AligNet框架:用层次化视觉概念对齐人类
【10月更文挑战第18天】这篇论文提出了一种名为AligNet的框架,旨在通过将人类知识注入神经网络来解决其与人类认知的不匹配问题。AligNet通过训练教师模型模仿人类判断,并将人类化的结构和知识转移至预训练的视觉模型中,从而提高模型在多种任务上的泛化能力和稳健性。实验结果表明,人类对齐的模型在相似性任务和出分布情况下表现更佳。
69 3
|
2月前
|
机器学习/深度学习 人工智能 编解码
探索生成对抗网络(GANs):人工智能领域的革新力量
【10月更文挑战第14天】探索生成对抗网络(GANs):人工智能领域的革新力量
87 1
|
14天前
|
机器学习/深度学习 人工智能 算法
深入解析图神经网络:Graph Transformer的算法基础与工程实践
Graph Transformer是一种结合了Transformer自注意力机制与图神经网络(GNNs)特点的神经网络模型,专为处理图结构数据而设计。它通过改进的数据表示方法、自注意力机制、拉普拉斯位置编码、消息传递与聚合机制等核心技术,实现了对图中节点间关系信息的高效处理及长程依赖关系的捕捉,显著提升了图相关任务的性能。本文详细解析了Graph Transformer的技术原理、实现细节及应用场景,并通过图书推荐系统的实例,展示了其在实际问题解决中的强大能力。
97 30
|
23天前
|
机器学习/深度学习 自然语言处理 语音技术
Python在深度学习领域的应用,重点讲解了神经网络的基础概念、基本结构、训练过程及优化技巧
本文介绍了Python在深度学习领域的应用,重点讲解了神经网络的基础概念、基本结构、训练过程及优化技巧,并通过TensorFlow和PyTorch等库展示了实现神经网络的具体示例,涵盖图像识别、语音识别等多个应用场景。
48 8
|
22天前
|
机器学习/深度学习 算法
生成对抗网络(Generative Adversarial Networks,简称GANs)
生成对抗网络(GANs)由Ian Goodfellow等人于2014年提出,是一种通过生成器和判别器的对抗训练生成逼真数据样本的深度学习模型。生成器创造数据,判别器评估真实性,两者相互竞争优化,广泛应用于图像生成、数据增强等领域。
|
2月前
|
机器学习/深度学习 编解码 人工智能
技术前沿探索:生成对抗网络(GANs)的革新之路
【10月更文挑战第14天】技术前沿探索:生成对抗网络(GANs)的革新之路
43 2
|
2月前
|
机器学习/深度学习 编解码 人工智能
技术前沿探索:生成对抗网络(GANs)的革新之路
【10月更文挑战第14天】技术前沿探索:生成对抗网络(GANs)的革新之路
59 1
|
2月前
|
机器学习/深度学习 数据采集 算法
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)
这篇博客文章介绍了如何使用包含多个网络和多种训练策略的框架来完成多目标分类任务,涵盖了从数据准备到训练、测试和部署的完整流程,并提供了相关代码和配置文件。
63 0
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)