@TOC
TransGAN是UT-Austin、加州大学、 IBM研究院的华人博士生构建了一个只使用纯 transformer 架构、完全没有卷积的 GAN,并将其命名为 TransGAN。该论文已被NeruIPS(Conference and Workshop on Neural Information Processing Systems,计算机人工智能领域A类会议)录用,文章发表于2021年12月。
该文章旨在仅使用Transformer网络设计GAN。Can we build a strong GAN completely free of convolutions?
论文地址:https://arxiv.org/abs/2102.07074
代码地址:https://github.com/VITA-Group/TransGAN
本博客是精读这篇论文的报告,包含一些个人理解、知识拓展和总结。
一、原文摘要
最近,人们对Transformer产生了爆炸性的兴趣,这表明Transformer有可能成为计算机视觉任务(如分类、检测和分割)的强大“通用”模型。虽然这些尝试主要研究区分模型,但我们探索了一些更为困难的视觉任务,例如生成性对抗网络(GAN)。
我们的目标是进行第一次试验性研究,仅使用纯Transformer架构,构建完全没有卷积的GAN,我们的vanilla-GAN架构被称为TransGAN,包括一个基于内存友好的Transformer的生成器,该生成器可逐渐提高特征分辨率,并相应地包含一个多尺度鉴别器,可同时捕获语义上下文和低级纹理。
在此基础上,我们引入了新的网格自关注模块,以进一步缓解内存瓶颈,从而将TransGAN扩展到高分辨率生成。我们还开发了一个独特的训练配方,包括一系列可以缓解TransGAN训练不稳定性问题的技术,如数据增强、修改的标准化和相对位置编码。
与目前最先进的使用卷积主干的GANs相比,我们的架构实现了极具竞争力的性能。TransGAN能够生成具有高保真度和合理纹理细节的各种视觉示例。此外,通过可视化训练动态,我们深入研究了基于Transformer的生成模型,以了解它们的行为与卷积模型的区别。
二、介绍
而本文主要创新点如下:
- 新颖的结构设计:第一次使用纯粹的Transformer来构建无卷积的GAN。TransGAN定制了一个便于记忆的生成器和多尺度鉴别器,并进一步配备了一种新的网格自我注意机制。这些体系结构组件经过深思熟虑的设计,以平衡内存效率、全局特征统计和局部细节与空间差异。
- 新的训练方法:研究了一些技术来更好地训练TransGAN,包括利用数据增强、修改层规范化,以及对生成器和鉴别器采用相对位置编码。且进行了广泛的消融研究和讨论。
- 与当前最先进的GANs相比,TransGAN实现了极具竞争力的性能。
三、为什么提出TransGAN?
- 一方面传统GAN存在模式崩溃问题,训练不稳定,在这几年里为了致力于稳定GAN训练,研究者们引入了各种正规化术语,更好的损失函数,以及各种变体训练方法,但是从2015的DC-GAN使用CNN架构来扩展GAN以来,每一个成功的GAN都依赖于基于CNN的生成器和鉴别器。
- 另一方面传统GAN都是基于卷积的,缺点是只有局部感受野,深层次会丢失细节。
最初的transformer是为NLP设计的,在NLP中,多头自我注意层和前向反馈网络层层被堆叠起来,以捕捉单词之间的长期相关性,最近,Transformer在图像生成方面也有进展,通过替换CNN的某些组件,将Transformer模块结合到图像生成模型中,然而其CNN的整体架构仍然存在(包括用于发生器的CNN编码器/解码器,以及完全基于CNN的鉴别器)。
四、主要框架
在这里插入图片描述
4.1、生成器
如果以逐个像素作为输入,32*32的低分辨率图像也会导致1024长度的序列,与单词序列相比,数据指数级增长,如果再加入注意力,则参数爆炸式增长。于是作者的策略是分阶段迭代提高分辨率,即增加输入序列同时逐渐降低维数。
在这里插入图片描述
- 输入为随机噪声,首先通过多层感知器(MLP)形成一段长序列。
- 然后序列经过transformer的encoder块,输出长序列。
- 上采样模块包括重塑、上采样和重塑阶段。其首先将该长序列重塑为8×8×C(将1D的序列转换成了2D的图像特征),然后使用双三次插值的方法进行上采样,使之在维度不减的情况下提高采样分辨率,变成16×16×C的图像特征。然后又一次重塑成1D的序列。
- 将重塑后的1D的序列再次经过步骤2、步骤3,生成32×32×C的图像特征,重塑成1D序列。又一次经过步骤2、步骤3生成64×64×C的图像特征,重塑成1D序列。然后再次经过步骤2,但不急着做步骤3。
- 此时的上采样模块进行了改进,与3不同的是双三次插值改成了pixel shuffle模块,将4的长序列输入重塑为64×64×C的图像特征(),使用pixel shuffle进行上采样,变为的图像特征(),然后将其又重塑为1D的序列。
- 后面即重复一次transformer的encoder块,然后将生成的长序列重塑为,再次经过一次pixel shuffle,从特征变为特征,然后最后进行一次线性加权,得到256×256×3的图像。
4.2、鉴别器
鉴别器的任务是区分真假图像,也就是分类任务。作者设计了一个多尺度的鉴别器,在不同的阶段以不同大小的面片作为输入。(因为三种不同的序列能够同时提取语义结构和纹理细节。)
在这里插入图片描述
- 首先将图像分割成同样大小的P×P、2P×2P、4P×4P个块,作为不同的尺度。
- 图像第一个尺度的大小为,首先通过线性加权将其转换成,同样第二个尺度转换成,第三个尺度转换成。第一个尺度转成的token是为了给第一个的transformer块作为输入,第二个和第三个分别连接到第二、三阶段的token后(捕捉更多的纹理信息)。
- 与生成器反过来类似,我们首先将token输入transformer块,然后将输出的一维向量重塑为二维特征图,并在每个阶段之间采用平均池层对特征图分辨率进行降采样。
- 在这些块的末尾,在1D序列的开始处附加[cls]标记,以输出真/假预测。
4.3、Self-Attention的一种变体:Grid Self-Attention
self-attn
Self-attention虽然使生成器能够捕获全局对应关系,但在建模高的分辨率时,会出现超长序列,会极大影响效率,于是作者提出了Grid Self-Attention:
在这里插入图片描述
Grid Self-Attention将全尺寸特征映射划分为几个非重叠网格,网格内进行Self-attention(分成多个块,块内做标准的self-attention,然后将每个块相连)。
Grid Self-Attention在TransGAN中,只被运用在64×64以上分辨率以减少消耗,64以下的仍然采用标准的self-attention。这样的做法从战略上平衡局部细节和全局效率。
五、改进性策略
5.1、数据增强
对比卷积来说,Transforme是更需要数据的,不同类型的强大数据增强可以为Transformer提供高效的训练。
作者从三个角度进行了数据增强:Translation, Cutout, Color,让TransGAN的性能有了惊人的提高。
Translation是做些许偏移,Cutout在图像上加一些纯白或者纯黑的像素点,Color就是改变图像的对比度、饱和度。
5.2、相对位置编码
虽然经典的transformer已经有相对位置编码,但是其发挥出的作用不够明显。
作者将改为,其中E取自矩阵M,并作为残差项添加(M是同时考虑H轴和W轴,用表示相对位置的参数化矩阵)
相对位置编码学习了内容之间更强的“关系”,能够极大提升性能。
5.3、修正后的归一化
归一化层(Normalization )有助于稳定深层神经网络的深层学习训练,效果显著,原版标准归一化使用的是layer normalization,作者提出了一种,其中,X和Y表示缩放层前后的标记,C代表嵌入维度。(类似于AlexNet中曾经使用的局部响应规范化)
六、实验
6.1、数据集
CIFAR-10、STL10和CelebA数据集。
6.2、实验设置
遵循WGAN的设置,并使用WGAN-GP损失, 生成器的batch大小为128,鉴别器的batch大小为64,选择DiffAug作为培训过程中的基本增强策略。评价指标使用IS和FID。
6.3、实验结果
实验细节见论文
在这里插入图片描述
在这里插入图片描述
6.4、消融实验
实验细节见论文
在这里插入图片描述
6.5、实验消耗
实验细节见论文
在这里插入图片描述