论文精读 TransGAN:两个纯粹的Transformer可以组成一个强大的GAN(TransGAN:Two Pure Transformers Can Make One Strong GAN)

简介: TransGAN是UT-Austin、加州大学、 IBM研究院的华人博士生构建了一个只使用纯 transformer 架构、完全没有卷积的 GAN,并将其命名为 TransGAN。该论文已被NeruIPS(Conference and Workshop on Neural Information Processing Systems,计算机人工智能领域A类会议)录用,文章发表于2021年12月。该文章旨在仅使用Transformer网络设计GAN。Can we build a strong GAN completely free of convolutions?论文地址:https://

@TOC

TransGAN是UT-Austin、加州大学、 IBM研究院的华人博士生构建了一个只使用纯 transformer 架构、完全没有卷积的 GAN,并将其命名为 TransGAN。该论文已被NeruIPS(Conference and Workshop on Neural Information Processing Systems,计算机人工智能领域A类会议)录用,文章发表于2021年12月。

该文章旨在仅使用Transformer网络设计GAN。Can we build a strong GAN completely free of convolutions?

论文地址:https://arxiv.org/abs/2102.07074

代码地址:https://github.com/VITA-Group/TransGAN

本博客是精读这篇论文的报告,包含一些个人理解、知识拓展和总结。

一、原文摘要

最近,人们对Transformer产生了爆炸性的兴趣,这表明Transformer有可能成为计算机视觉任务(如分类、检测和分割)的强大“通用”模型。虽然这些尝试主要研究区分模型,但我们探索了一些更为困难的视觉任务,例如生成性对抗网络(GAN)。

我们的目标是进行第一次试验性研究,仅使用纯Transformer架构,构建完全没有卷积的GAN,我们的vanilla-GAN架构被称为TransGAN,包括一个基于内存友好的Transformer的生成器,该生成器可逐渐提高特征分辨率,并相应地包含一个多尺度鉴别器,可同时捕获语义上下文和低级纹理。

在此基础上,我们引入了新的网格自关注模块,以进一步缓解内存瓶颈,从而将TransGAN扩展到高分辨率生成。我们还开发了一个独特的训练配方,包括一系列可以缓解TransGAN训练不稳定性问题的技术,如数据增强、修改的标准化和相对位置编码

与目前最先进的使用卷积主干的GANs相比,我们的架构实现了极具竞争力的性能。TransGAN能够生成具有高保真度和合理纹理细节的各种视觉示例。此外,通过可视化训练动态,我们深入研究了基于Transformer的生成模型,以了解它们的行为与卷积模型的区别。

二、介绍

而本文主要创新点如下:

  1. 新颖的结构设计:第一次使用纯粹的Transformer来构建无卷积的GAN。TransGAN定制了一个便于记忆的生成器和多尺度鉴别器,并进一步配备了一种新的网格自我注意机制。这些体系结构组件经过深思熟虑的设计,以平衡内存效率、全局特征统计和局部细节与空间差异。
  2. 新的训练方法:研究了一些技术来更好地训练TransGAN,包括利用数据增强、修改层规范化,以及对生成器和鉴别器采用相对位置编码。且进行了广泛的消融研究和讨论。
  3. 与当前最先进的GANs相比,TransGAN实现了极具竞争力的性能

三、为什么提出TransGAN?

  1. 一方面传统GAN存在模式崩溃问题,训练不稳定,在这几年里为了致力于稳定GAN训练,研究者们引入了各种正规化术语,更好的损失函数,以及各种变体训练方法,但是从2015的DC-GAN使用CNN架构来扩展GAN以来,每一个成功的GAN都依赖于基于CNN的生成器和鉴别器。
  2. 另一方面传统GAN都是基于卷积的,缺点是只有局部感受野,深层次会丢失细节

最初的transformer是为NLP设计的,在NLP中,多头自我注意层和前向反馈网络层层被堆叠起来,以捕捉单词之间的长期相关性,最近,Transformer在图像生成方面也有进展,通过替换CNN的某些组件,将Transformer模块结合到图像生成模型中,然而其CNN的整体架构仍然存在(包括用于发生器的CNN编码器/解码器,以及完全基于CNN的鉴别器)。

四、主要框架

在这里插入图片描述

4.1、生成器

如果以逐个像素作为输入,32*32的低分辨率图像也会导致1024长度的序列,与单词序列相比,数据指数级增长,如果再加入注意力,则参数爆炸式增长。于是作者的策略是分阶段迭代提高分辨率,即增加输入序列同时逐渐降低维数。

在这里插入图片描述

  1. 输入为随机噪声,首先通过多层感知器(MLP)形成一段长序列。
  2. 然后序列经过transformer的encoder块,输出长序列。
  3. 上采样模块包括重塑、上采样和重塑阶段。其首先将该长序列重塑为8×8×C(将1D的序列转换成了2D的图像特征image.png),然后使用双三次插值的方法进行上采样,使之在维度不减的情况下提高采样分辨率,变成16×16×C的图像特征。然后又一次重塑成1D的序列。
  4. 将重塑后的1D的序列再次经过步骤2、步骤3,生成32×32×C的图像特征,重塑成1D序列。又一次经过步骤2、步骤3生成64×64×C的图像特征,重塑成1D序列。然后再次经过步骤2,但不急着做步骤3。
  5. 此时的上采样模块进行了改进,与3不同的是双三次插值改成了pixel shuffle模块,将4的长序列输入重塑为64×64×C的图像特征(image.png),使用pixel shuffle进行上采样,变为image.png的图像特征(image.png),然后将其又重塑为1D的序列。
  6. 后面即重复一次transformer的encoder块,然后将生成的长序列重塑为image.png,再次经过一次pixel shuffle,从image.png特征变为image.png特征,然后最后进行一次线性加权,得到256×256×3的图像。

4.2、鉴别器

鉴别器的任务是区分真假图像,也就是分类任务。作者设计了一个多尺度的鉴别器,在不同的阶段以不同大小的面片作为输入。(因为三种不同的序列能够同时提取语义结构和纹理细节。)

在这里插入图片描述

  1. 首先将图像分割成同样大小的P×P、2P×2P、4P×4P个块,作为不同的尺度。
  2. 图像第一个尺度的大小为image.png,首先通过线性加权将其转换成image.png,同样第二个尺度image.png转换成image.png,第三个尺度image.png转换成image.png。第一个尺度转成的token是为了给第一个的transformer块作为输入,第二个和第三个分别连接到第二、三阶段的token后(捕捉更多的纹理信息)。
  3. 与生成器反过来类似,我们首先将token输入transformer块,然后将输出的一维向量重塑为二维特征图,并在每个阶段之间采用平均池层对特征图分辨率进行降采样。
  4. 在这些块的末尾,在1D序列的开始处附加[cls]标记,以输出真/假预测。

4.3、Self-Attention的一种变体:Grid Self-Attention

self-attn

Self-attention虽然使生成器能够捕获全局对应关系,但在建模高的分辨率时,会出现超长序列,会极大影响效率,于是作者提出了Grid Self-Attention:

在这里插入图片描述

Grid Self-Attention将全尺寸特征映射划分为几个非重叠网格,网格内进行Self-attention(分成多个块,块内做标准的self-attention,然后将每个块相连)。

Grid Self-Attention在TransGAN中,只被运用在64×64以上分辨率以减少消耗,64以下的仍然采用标准的self-attention。这样的做法从战略上平衡局部细节和全局效率。

五、改进性策略

5.1、数据增强

对比卷积来说,Transforme是更需要数据的,不同类型的强大数据增强可以为Transformer提供高效的训练。

作者从三个角度进行了数据增强:Translation, Cutout, Color,让TransGAN的性能有了惊人的提高

Translation是做些许偏移,Cutout在图像上加一些纯白或者纯黑的像素点,Color就是改变图像的对比度、饱和度。

5.2、相对位置编码

虽然经典的transformer已经有相对位置编码,但是其发挥出的作用不够明显。

作者将image.png改为image.png,其中E取自矩阵M,并作为残差项添加(M是同时考虑H轴和W轴,用表示相对位置的参数化矩阵image.png

相对位置编码学习了内容之间更强的“关系”,能够极大提升性能。

5.3、修正后的归一化

归一化层(Normalization )有助于稳定深层神经网络的深层学习训练,效果显著,原版标准归一化使用的是layer normalization,作者提出了一种image.png,其中image.png,X和Y表示缩放层前后的标记,C代表嵌入维度。(类似于AlexNet中曾经使用的局部响应规范化)

六、实验

6.1、数据集

CIFAR-10、STL10和CelebA数据集。

6.2、实验设置

遵循WGAN的设置,并使用WGAN-GP损失, 生成器的batch大小为128,鉴别器的batch大小为64,选择DiffAug作为培训过程中的基本增强策略。评价指标使用IS和FID。

6.3、实验结果

实验细节见论文

在这里插入图片描述

在这里插入图片描述

6.4、消融实验

实验细节见论文

在这里插入图片描述

6.5、实验消耗

实验细节见论文

在这里插入图片描述

相关文章
|
机器学习/深度学习 编解码 自然语言处理
Vision Transformer 必读系列之图像分类综述(二): Attention-based(上)
Transformer 结构是 Google 在 2017 年为解决机器翻译任务(例如英文翻译为中文)而提出,从题目中可以看出主要是靠 Attention 注意力机制,其最大特点是抛弃了传统的 CNN 和 RNN,整个网络结构完全是由 Attention 机制组成。为此需要先解释何为注意力机制,然后再分析模型结构。
691 0
Vision Transformer 必读系列之图像分类综述(二): Attention-based(上)
|
3月前
|
机器学习/深度学习 JavaScript 算法
深度学习500问——Chapter07:生成对抗网络(GAN)(1)
深度学习500问——Chapter07:生成对抗网络(GAN)(1)
81 3
|
3月前
|
机器学习/深度学习 编解码 自然语言处理
深度学习500问——Chapter07:生成对抗网络(GAN)(3)
深度学习500问——Chapter07:生成对抗网络(GAN)(3)
65 0
|
机器学习/深度学习 编解码 自然语言处理
Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation论文解读
在过去的几年中,卷积神经网络(CNN)在医学图像分析方面取得了里程碑式的进展。特别是基于U型结构和跳跃连接的深度神经网络在各种医学图像任务中得到了广泛的应用。
597 0
|
机器学习/深度学习 算法 数据可视化
深度学习论文阅读目标检测篇(一):R-CNN《Rich feature hierarchies for accurate object detection and semantic...》
 过去几年,在经典数据集PASCAL上,物体检测的效果已经达到 一个稳定水平。效果最好的方法是融合了多种低维图像特征和高维上 下文环境的复杂集成系统。在这篇论文里,我们提出了一种简单并且 可扩展的检测算法,可以在VOC2012最好结果的基础上将mAP值提 高30%以上——达到了53.3%。
143 0
深度学习论文阅读目标检测篇(一):R-CNN《Rich feature hierarchies for accurate object detection and semantic...》
|
机器学习/深度学习 人工智能 算法
CVPR‘2023 | Cross-modal Adaptation: 基于 CLIP 的微调新范式
CVPR‘2023 | Cross-modal Adaptation: 基于 CLIP 的微调新范式
1181 0
|
机器学习/深度学习 数据挖掘 Go
深度学习论文阅读图像分类篇(五):ResNet《Deep Residual Learning for Image Recognition》
更深的神经网络更难训练。我们提出了一种残差学习框架来减轻 网络训练,这些网络比以前使用的网络更深。我们明确地将层变为学 习关于层输入的残差函数,而不是学习未参考的函数。我们提供了全 面的经验证据说明这些残差网络很容易优化,并可以显著增加深度来 提高准确性。在 ImageNet 数据集上我们评估了深度高达 152 层的残 差网络——比 VGG[40]深 8 倍但仍具有较低的复杂度。这些残差网络 的集合在 ImageNet 测试集上取得了 3.57%的错误率。这个结果在 ILSVRC 2015 分类任务上赢得了第一名。我们也在 CIFAR-10 上分析 了 100 层和 1000 层的残差网络。
205 0
|
缓存 算法 PyTorch
YOLOv5的Tricks | 【Trick12】YOLOv5使用的数据增强方法汇总
YOLOv5的Tricks | 【Trick12】YOLOv5使用的数据增强方法汇总
2742 0
YOLOv5的Tricks | 【Trick12】YOLOv5使用的数据增强方法汇总
|
机器学习/深度学习 算法 计算机视觉
YOLOv5的Tricks | 【Trick2】目标检测中进行多模型推理预测(Model Ensemble)
在学习yolov5代码的时候,发现experimental.py文件中有一个很亮眼的模块:Ensemble。接触过机器学习的可能了解到,机器学习的代表性算法是随机森林这种,使用多个模型来并行推理,然后归纳他们的中值或者是平均值来最为整个模型的最后预测结构,没想到的是目标检测中也可以使用,叹为观止。下面就对其进行详细介绍:
1309 1
|
机器学习/深度学习 编解码 自然语言处理
论文阅读笔记 | Transformer系列——Focal Transformer
论文阅读笔记 | Transformer系列——Focal Transformer
181 0
论文阅读笔记 | Transformer系列——Focal Transformer