更好的性能!新型自监督学习方法 CAE 了解一下

简介: 1)MIM 方法中,网络结构的哪个部分是学习表征的,哪个部分是解决 pretext task?2)为什么之前典型的 contrastive learning 方法,在下游任务(例如检测、分割)上只能取得跟 supervised pretraining 类似的性能?3)MIM 方法为什么优于目前的 contrastive learning 方法?

来自北京大学、香港大学和百度的研究者们近日提出了一种名为 CAE 的新型 MIM 方法。该方法通过对 “表征学习” 和 “解决前置任务(pretext task)” 这两个功能做完全分离,使得 encoder 学习到更好的表征,从而在下游任务上实现了更好的泛化性能。


今天我们有幸邀请到了研究者之一的 Xiaokang Chen,他将为我们带来该方法的深度解读。

640.gif




前言



Mask Image Modeling (MIM) 方法,在 NLP 领域(例如 BERT)得到了广泛的应用。随着 ViT 的提出和发展,人们也尝试将 MIM 应用到视觉领域并取得了一定进展。在此之前,视觉自监督算法主要沿着 contrastive learning 的思路去设计,而 MIM 无疑打开了新的大门。


我们最近的工作 “Context Autoencoder for Self-Supervised Representation Learning”,提出了一种新的 MIM 方法 CAE,通过对 “表征学习” 和 “解决 pretext task” 这两个功能做完全分离,使得 encoder 学习到更好的表征,从而在下游任务实现了更好的泛化性能。我们尝试回答如下几个问题:


1)MIM 方法中,网络结构的哪个部分是学习表征的,哪个部分是解决 pretext task?


2)为什么之前典型的 contrastive learning 方法,在下游任务(例如检测、分割)上只能取得跟 supervised pretraining 类似的性能?


3)MIM 方法为什么优于目前的 contrastive learning 方法?

640.png


1. 背景



MIM 是一种自监督表征学习算法。它的主要思路是:对输入图像进行分块和随机掩码操作,然后对掩码区域做一些预测。预测的目标可以是 Token ID(BEiT),也可以是 RGB 的值(MAE)。通过 MIM,我们希望 encoder 能学习到一个好的表征,从而在下游任务取得良好的泛化性能。


近期 MIM 有两个代表性工作:BEiT 和 MAE。


BEiT 使用一个 encoder 做两件事:(1) 学习一个好的图像表征;(2) 解决 pretext task:预测 masked patch 的 Token ID。encoder 的潜力并没有完全被挖掘,只有部分被用来学习表征。


MAE 使用了 encoder-decoder 架构,encoder 负责对 visible patch 进行表征学习,decoder 将 visible 和 masked patch 的表征(masked patch 使用一个可学习的向量)作为输入,预测 masked patch 的 RGB 值。但是,MAE 在 decoder 中也会对 visible patch 的表征进行改变。与此同时,MAE decoder 利用改变后的可见区域的表征去预测遮挡区域的表征,实际上 decoder 也负责了一部分表征学习的功能。然而,在下游任务中,只有 encoder 中学到的信息能被拿来用,那么即使在 decoder 中进一步学到了更好的表征,也无法利用到下游任务中。


以上两种方法,都没有充分挖掘 encoder 的潜力,限制了预训练学习到的表征质量。


2. Context Autoencoder (CAE)



CAE 设计的核心思想是对 “表征学习” 和 “解决 pretext task” 这两个功能做分离。我们希望在预训练时,表征学习的任务只交给 encoder,而 decoder 只负责解决 pretext task。这样我们希望从 encoder 出来的表征就是非常好的,而不需要额外的模块对表征进一步 refine(例如 MAE 中的 decoder),可以尽可能大地挖掘 encoder 的潜力。


如下图所示,CAE 包括 4 个部分:(1) encoder; (2) latent contextual regressor; (3) decoder; (4) alignment module。640.png

输入图像通过随机掩码被划分成 visible patch 和 masked patch 两个部分。具体来说:


Encoder


Encoder 是一个 ViT 模型,负责学习 visible patch 的表征   。


Latent contextual regressor

Latent contextual regressor 通过   image.png 预测 masked patch 的表征   image.png。Latent contextual regressor 由一系列 cross-attention module 组成,query 是 masked patch 的表征,key 和 value 是全部 patch 的表征。


在计算 query-key 相似度时,我们会引入每个 patch 对应的位置编码。在这个阶段,   image.png不断更新、变得更加准确,而  image.png  不会更新,对图像特征的提取这个任务完全交给 encoder。


Decoder

Decoder 只拿    image.png和 对应的位置编码作为输入,其目的是通过 image.png   预测 masked patch 的某些性质,比如由训练好的 tokenizer 产生的 Token ID,或者 RGB 的值。本文的实验 follow BEiT,使用 DALL-E tokenizer 对输入图像 token 化,得到 decoder 的目标。


Latent representation alignment


Latent representation alignment 是非常关键的一部分。虽然 visible patch 的表征在 encoder 之后就不会改变,但 latent contextual regressor 可能会”偷偷地“学习 masked patch 的表征,然后基于这样一个比 encoder 的输出更好的表征在 decoder 中进行预测。


如果是这样,那么 latent contextual regressor 也承担了一部分表征学习的功能,这与我们想要的“分离”是相悖的。于是我们通过对 image.png   添加约束,希望 latent contextual regressor 的输出和 encoder 的输出在同一编码空间中,这样表征学习的任务还是落到了 encoder 的身上。


我们将图像的 masked patch 也输入到 encoder,获得这部分的表征 image.png  。  image.png 将作为 image.png   学习的目标。计算    image.png的过程不会计算梯度。


损失函数



损失函数由两部分组成:(1) 对 decoder 预测的监督,使用 cross-entropy loss; (2) 对   image.png 和   image.png 的 align 的监督,使用 MSE loss。


3. 分析



3.1 CAE 关心每个 patch 的表征


CAE 基于 visible patch 的表征,从随机采样的 masked patch 做一些预测,这要求 CAE 关心每个 patch 的语义。这不同于典型的对比学习方法(例如 MoCo v3, SimCLR),这类方法只关心图像的全局语义,忽略了图像的细节和非主体区域(比如背景)。


3.2 Latent contextual regressor 的输出和 encoder 的输出在同一编码空间中


我们对 latent contextual regressor 的输出做了约束,希望它能和 encoder 的输出尽可能处于同一编码空间中。这样,decoder 会基于 encoder 学到的编码空间做预测,将对图像的特征提取的重任完全交到了 encoder 身上,驱使 encoder 学习到好的表征。


为了验证表征空间确实对齐了,我们用 RGB 值作为 decoder 目标 (考虑到 Token ID 难以可视化,这边使用 RGB),训练 CAE。在测试的时候,我们将全部 patch 输入到 encoder,然后跳过 latent contextual regressor,直接将 encoder 的输出送进 decoder,预测全部 patch 的 RGB 的值。下图展示了预测结果:第一行是原图,第二行是预测,我们发现:仅使用 encoder 和 decoder 就可以将图片重建出来。这说明 encoder 的输出和 latent contextual regressor 的输出属于同一编码空间。

640.jpg

如果训练时不做 alignment 约束,那么无法重建。如下图所示:输出都是乱码,说明 encoder 输出和 latent contextual regressor 的输出不在一个编码空间中。这使得 regressor 也承担了一部分表征学习的角色,使得 encoder 学到的表征质量有所欠缺,在消融实验部分也有验证。

640.jpg

3.3 CAE 学到的表征可以区分不同类别的 object/stuff


CAE 基于 visible patch 的表征,在 masked patch 区域做预测,这要求 CAE 对 visible patch 的内容有比较好的理解。


举例来说,我们看到一只狗的头部,我们可以预测出它的身体部分;我们看到一小片天空,我们也能预测出它的周围大概率也是一片天空。因此,我们认为:CAE 学到的表征可以区分不同类别的 object/stuff


为了验证这一点,我们从 ADE20K 数据集随机采样一些图片输入到 encoder。因为 ADE20K 提供了每个像素的类别标签(150 类),我们可以使用 t-SNE 对 encoder 输出的表征进行可视化。如下图所示:每个颜色代表一个类别,左图是 CAE,右图是随机初始化的 encoder。我们可以发现 CAE 可以有效区分不同类别的 object/stuff(因为是在 ImageNet-1K 进行预训练,所以区分地不够完美),而随机初始化的 encoder 无法做到这一点。

640.jpg


3.4 典型的 contrastive learning 为什么在下游任务只能取得跟 supervised pre-training 差不多的结果?


在 contrastive learning 中,random crop 是一个非常重要的数据增强策略。典型的 contrastive learning (比如 MoCo v3)希望最大化来自同一图像的 2 个不同 crop 之间的全局语义相似度,而最小化来自不同图像的 crop 之间的相似度。


这样为什么能奏效呢?我们首先分析 random crop 的性质。在 SimCLR 论文中提到,random crop 是 contrastive learning 方法中非常重要的数据增强策略。在 ImageNet-1K 数据集中,图像的主体物体大多处于图像的中心区域,而对图像进行 random crop,中心区域有很大的概率被囊括进去。例如下图展示的几个例子,几次 crop 基本都包括了图像的主体物体。

640.jpg

对同一图像的不同 crop 提取全局语义,实际上学到的是原始图像中主体物体的特征。正因如此,同一图像的不同 crop 之间才可能相似。在 supervised pre-training 中,受到图像分类标签的约束,网络学习到的也是图像主体区域的特征,这和 contrastive learning 学到的知识有很大的相似之处,因此在下游任务表现类似。


3.5 MIM 和 contrastive learning 的区别


MIM 方法(例如 CAE)基于 visible patch 的表征,对 masked patch 区域做预测。在做随机掩码时,图像的每个 patch(例如背景区域的 object/stuff)都有可能被考虑到,而不仅仅是图像的主体区域。为了做好 masked patch 的预测,CAE 会学好每个 patch 的表征。


我们对 CAE 以及 MoCo v3 的 attention map 做了可视化。如下图所示:第一行是原图,第二行是 MoCo v3,第三行是 CAE。


红色表示 attention value 更高,蓝色表示 attention value 更低。处于蓝色边界内部的区域,通过这样的原则筛选:将 attention value 从大到小排序后,保留累计和达到所有位置 attention value 总和的 50% 的部分。


我们可以看到:MoCo v3 的 attention map 主要在图像的主体区域有高响应,而 CAE 能考虑到几乎所有 patch。

640.jpg


实验



我们使用 ViT-small 和 ViT-base 在 ImageNet-1K 上进行实验。输入图像的分辨率是  image.png ,patch 大小是  image.png ,一张图会被划分成   image.png 个 patch。


4.1 Pretraining Evaluation


自监督学习广泛使用 linear probing 去评测预训练表征的好坏:将 encoder 的参数固定住,在之后加一个 linear classifier 进行图像分类。


我们认为 linear probing 不适合 MIM 方法,因为 MIM 方法通常会学到每个 patch 的表征,不仅包含主体物体的信息,还学到了背景等若干知识,这是多而杂的,不适合直接进行线性分类。


我们额外提出一种新的测试指标:attentive probing。我们在固定参数的 encoder 后加上一个简单的 cross-attention module(没有 FFN)和一个 linear classifier,通过注意力机制动态地选择适合做图像分类的信息。


我们对 attentive probing 阶段使用的 cross-attention module 做 attention map 可视化,发现可以关注到主体物体,如下图所示:

640.png

Finetune、linear probing、attentive probing 的结果见下表:

640.png

我们发现了一些有趣的现象:


(1) contrastive learning 方法(MoCo v3, DINO)的 linear probing 和 attentive probing 结果类似。这说明这类方法在预训练时已经将注意力放到了图像的主体物体上面,无需进一步动态筛选即可做好图像分类,这也与之前我们对 contrastive learning 的分析一致。


(2) MIM 方法(例如 CAE)的 attentive probing 相比 linear probing 有很大的提升。这说明 MIM 方法学到了每个 patch 的特征,而不仅仅是图像主体物体的,因此需要做一些筛选才利于图像分类。


4.2 消融实验


我们对 decoder 和 alignment module 进行消融实验,见下表。单加一个 decoder 能提高 attentive probing 的结果,但在下游任务(分割、检测)提升却不明显。使用 alignment module 之后能显著提升下游任务的性能,说明约束 encoder 的输出和 latent contextual regressor 的输出在同一编码空间非常重要,能提升 encoder 学到的表征质量。


640.png


4.3 语义分割


我们在 ADE20K 进行语义分割的实验,实验结果如下图所示。网络使用 UperNet,迭代次数为 160K,输入图像分辨率为 512*512,使用单尺度测试。


Contrastive learning 方法和 supervised pre-training(DeiT)的结果类似,而 CAE 能取得明显更好的结果。


跟其他 MIM 方法相比,CAE 的结果也更好,说明预训练阶段 encoder 被充分利用,学到的表征更好。使用 ViT-Large,我们能获得 54.7 mIoU,比 MAE(53.6)高 1.1 个点,比有监督学习 DeiT(49.9)高 4.8 个点。


非常有趣的是:CAE 和 DeiT 的 gap,在 large 模型下会比 base 模型更大,base 模型下它们的 gap 只有 3.2 个点。这说明了 CAE 能很好地泛化到大模型。

640.png


4.4 物体检测、实例分割


我们使用 Mask-RCNN 和 Cascade-RCNN 两种网络结构进行物体检测和实例分割的实验。我们使用 multi-scale training 训练 12 epoch(1x schedule),测试阶段仅使用单尺度测试。


如下面 2 张图片所示,实验现象和语义分割的类似:contrastive learning 方法和 supervised pre-training 方法结果类似且更差,CAE 的结果更好。使用 ViT-Large,在 1x schedule 下我们能获得 54.5 mAP。

640.png


总结



本文提出了 CAE,设计的核心有两点:


(1) 对 “表征学习” 和 “解决 pretext task” 这两个功能做完全分离;


(2) 在 visible patch 学习到的表征空间中对 masked patch 做预测。


以上两点都是为了驱使 encoder 学习更好的表征,从而在下游任务取得良好的泛化能力。除此之外,我们对 supervised pre-training、contrastive learning 和 MIM 方法进行了分析,认为 contrastive learning 和 supervised pre-training 主要关注图像的主体区域(例如 ImageNet-1K 标签集中的物体),而 MIM 会关注图像的全部 patch,更有利于下游任务。


文章来源:【OpenMMLab

2022-05-06 18:07

目录
相关文章
|
2天前
|
机器学习/深度学习 编解码 数据可视化
南开大学提出YOLO-MS | 超越YOLOv8与RTMDet,即插即用打破性能瓶颈
南开大学提出YOLO-MS | 超越YOLOv8与RTMDet,即插即用打破性能瓶颈
51 1
|
2天前
|
计算机视觉
模型落地必备 | 南开大学提出CrossKD蒸馏方法,同时兼顾特征和预测级别的信息
模型落地必备 | 南开大学提出CrossKD蒸馏方法,同时兼顾特征和预测级别的信息
43 0
|
2天前
|
机器学习/深度学习 计算机视觉
YOLOv8改进 | 细节涨点篇 | UNetv2提出的一种SDI多层次特征融合模块(分割高效涨点)
YOLOv8改进 | 细节涨点篇 | UNetv2提出的一种SDI多层次特征融合模块(分割高效涨点)
209 2
|
10月前
|
机器学习/深度学习
结合亲和力提高了 28.7 倍,基于端到端贝叶斯语言模型的方法设计大型、多样化的高亲和力抗体库
结合亲和力提高了 28.7 倍,基于端到端贝叶斯语言模型的方法设计大型、多样化的高亲和力抗体库
|
12月前
|
人工智能 自然语言处理 安全
表现优于 GPT-4,ChemCrow 集成 13 种化学工具,增强大型语言模型的化学性能
表现优于 GPT-4,ChemCrow 集成 13 种化学工具,增强大型语言模型的化学性能
238 0
|
12月前
|
机器学习/深度学习 人工智能 缓存
连夜卷出 | 超越所有YOLO检测模型,mmdet开源当今最强最快目标检测模型!(二)
连夜卷出 | 超越所有YOLO检测模型,mmdet开源当今最强最快目标检测模型!(二)
459 0
|
12月前
|
Go 计算机视觉 开发者
连夜卷出 | 超越所有YOLO检测模型,mmdet开源当今最强最快目标检测模型!(一)
连夜卷出 | 超越所有YOLO检测模型,mmdet开源当今最强最快目标检测模型!(一)
441 0
|
12月前
|
机器学习/深度学习 存储 缓存
VLDB 2022最佳研究论文:克服通信挑战,新框架SANCUS实现GNN高效训练
VLDB 2022最佳研究论文:克服通信挑战,新框架SANCUS实现GNN高效训练
|
机器学习/深度学习 编解码 测试技术
解决CNN固有缺陷, CCNN凭借单一架构,实现多项SOTA
解决CNN固有缺陷, CCNN凭借单一架构,实现多项SOTA
123 0
|
机器学习/深度学习 自然语言处理 算法
星际争霸II协作对抗基准超越SOTA,新型Transformer架构解决多智能体强化学习问题
星际争霸II协作对抗基准超越SOTA,新型Transformer架构解决多智能体强化学习问题
140 0