MAE 自监督算法介绍和基于 EasyCV 的复现

本文涉及的产品
交互式建模 PAI-DSW,5000CU*H 3个月
简介: 自监督学习(Self-Supervised Learning)能利用大量无标注的数据进行表征学习,然后在特定下游任务上对参数进行微调。通过这样的方式,能够在较少有标注数据上取得优于有监督学习方法的精度。近年来,自监督学习受到了越来越多的关注,如Yann Lecun也在 AAAI 上讲 Self-Supervised Learning 是未来的大势所趋。在CV领域涌现了如SwAV、MOCO、DINO、MoBY等一系列工作。MAE是kaiming继MOCO之后在自监督学习领域的又一力作。首先,本文会对MAE进行解读,然后基于EasyCV库的精度复现过程及其中遇到的一些问题作出解答。

导言

自监督学习(Self-Supervised Learning)能利用大量无标注的数据进行表征学习,然后在特定下游任务上对参数进行微调。通过这样的方式,能够在较少有标注数据上取得优于有监督学习方法的精度。近年来,自监督学习受到了越来越多的关注,如Yann Lecun也在 AAAI 上讲 Self-Supervised Learning 是未来的大势所趋。在CV领域涌现了如SwAV、MOCO、DINO、MoBY等一系列工作。MAE是kaiming继MOCO之后在自监督学习领域的又一力作。首先,本文会对MAE进行解读,然后基于EasyCV库的精度复现过程及其中遇到的一些问题作出解答。

概述

MAE的做法很简单:随机mask掉图片中的一些patch,然后通过模型去重建这些丢失的区域。包括两个核心的设计:1.非对称编码-解码结构 2.用较高的掩码率(75%)。通过这两个设计MAE在预训练过程中可以取得3倍以上的训练速度和更高的精度,如ViT-Huge能够通过ImageNet-1K数据上取得87.8%的准确率。

模型拆解

MAE属于自编码器(AutoEncoder)的一种,由编码器和解码器两个部分组成。类似于常见的自编码器,MAE会先通过编码器将图片patch映射到隐空间。然后,基于解码器将隐空间上的特征变量重构成图片patch。和常见自编码器的区别是非对称的编码解码结构。这个非对称性主要体现在以下两点:

  1. 轻量化的解码器结构
  2. 在编码器阶段,仅将未被mask掉的图片patch作为输入。在解码器阶段会将编码器输出的隐变量和mask token共同作为输入去重建完成的图片。

image.png

掩码策略

首先,直接采用ViT的做法将图片分成不重叠的patch(如vit-b会将图片划分成16x16的图像块),然后通过均匀采样策略对这些patch进行采样,并丢弃未被选中的部分。MAE所采用的掩码策略有如下两个特点:

1.在算法中,使用了75%的masking ratio来丢弃图片patch。作者指出,通过high masking ratio可以有效减少输入的冗余程度,使重建任务不能够通过简单的参考邻近patch来完成。文中,也通过实验证明了这一观点。
image.png
关于Masking ratio的实验是MAE最精彩的一部分,随着mask ratio的增加,fine-tuning和linear proing的精度逐渐攀升,甚至到75%还没有下降,这一点打破了BERT(15%)、BEiT(40%)的做法,进一步将mask 预训练方式在NLP领域的成功在CV领域实现复制。

2.采用了均匀采样策略可以有效的避免potential center bias(丢弃掉的patch都靠近图片中心)。对mask策略的消去实验如下表所示。
image.png

编码器

MAE encoder采用的是ViT结构。在对图像patch进行采样后,仅保留25%未被mask的图像patch作为输入,通过linear Projection进行编码后,加上positional embedding,然后输入到一系列的Transformer blocks中。相比于Bert中用mask token来代替被mask区域的做法,MAE encoder直接舍弃掉了mask的部分,通过这种方式可以有效的减少预训练过程中需要消耗的计算资源和训练时间。

文中,作者对编码器是否保留mask token进行了消融实验,可以看出在编码器阶段舍弃mask token不会对预训练模型的表征能力造成影响,同时能够显著的加速训练进程。
image.png

解码器

MAE decoder由一连串的Transfomer block组成。和encoder不同的是,MAE decoder的输入不仅包括未被mask的图像patch经过encoder编码后的特征,还包括了被mask掉的部分。对于mask掉部分的输入,会用一个共享参数,且可学习的mask token代替作为输入。除此之外,为了保证不同的mask token能够区分在图像中的不同位置,在输入到decoder之前,会对整体的输入加上positional embedding。

在MAE中,解码器仅会在预训练阶段用于图片的重建工作。文中采用了轻量化的解码器结构,对于每个token的计算量仅有相对于解码器的10%以下。通过这种设计,就算在解码阶段用了完整数量的token作为输入,对计算资源的消耗也不会显著增加。

文中,作者对解码器的depth和width两个维度进行对比实验,可以看出一个较轻量化的解码器,就足以是模型学习到有效的表征。
image.png

重建目标

MAE预训练任务的目标是重建被mask掉的像素值。MAE decoder输出关于每个图像patch的表征后,会经过一个linear projection层映射成与图像像素数目相同维度的向量(PxPx3)。仅采用MSE作为损失函数,计算预测向量和被mask掉像素值之前的MSE loss。

需要额外指出的是,作者使用了归一化后的图像patch作为重建的目标。通过实验证明,这种做法可以提升模型的表征能力。
image.png

模型评价

文中除了从linear probing和Finetuning两个角度对模型的表征能力做出评价外,还采用了Partial Fine-tuning的方式进行评价,相比于linear probing这种之前普遍采用的评价方式,能够更好的反映预训练模型对非线性特征的表征能力。从下图可以看出,MAE算法仅仅对一个transformer block进行fintune精度就从73.5%提升到81%。同时与MOCOv3相比,MOCOv3虽然在linear probing的时候具有更高的精度,但是在partial fine-tuning时,MAE的精度都要高于MOCOv3。可以看出,MAE虽然对线性特征的表征能力要弱于MOCOv3,但是具有更好的非线性特征表征能力。
image.png

EasyCV 介绍

EasyCV是阿里巴巴开源的基于Pytorch,以自监督学习和Transformer技术为核心的 all-in-one 视觉算法建模工具。在数据层面,EasyCV提供了提供了不同数据源(data_source)的抽象,支持多种开源数据集例如Cifar、ImageNet、CoCo等,并将各种数据预处理抽象成若干独立的pipeline,可以通过配置文件灵活的配置数据预处理流程。在API层面,提供了统一的训练、评估、模型导出、预测的API。因此,基于EasyCV,仅需要实现模型部分的代码,就可以很便捷的完成MAE的复现。

除此之外,EasyCV支持aliyun PAI产品中方便的进行部署(如PAI-DLC),无需多余的修改即可在DLC上同时进行多机或者多组实验,加快复现进度。

复现过程 & 踩坑总结

接下来我们介绍如何在EasyCV框架中进行MAE算法的复现和踩坑总结,首先,说明一下预训练的整体流程。
1.将输入图像划分成不同的patch,并将patch经过Linear Projection进行映射,再加上positional embedding得到image token

# embed patches
x = self.patch_embed(x)
# add pos embed w/o cls token
x = x + self.pos_embed[:, 1:, :]

2.将image token按75%的比例进行随机mask,通过随机生成的张量noise进行argsort操作的方式来完成对image patch的随机mask。其中,需要注意,该函数中额外传回两个参数mask和ids_restore。mask记录了mask patch在原始图片中的位置,用于后续损失函数的计算。ids_restore记录了传入encoder的image token在原始图片中的位置,用于后续再decoder前进行unshuffle操作。

def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

3.将保留的image token输入到encoder得到image embeding

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

4.将image embeding和mask token一起进行unshuffle操作,再加上positional embedding后,输入到decoder中

# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
    x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
x_ = torch.gather(
    x_,
    dim=1,
    index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
x = torch.cat([x[:, :1, :], x_], dim=1)

# add pos embed
x = x + self.decoder_pos_embed

5.将输出的vector与归一化后的image patch计算mse loss,并反向传播更新梯度。在计算loss时,有两个需要注意的点。1、首先,需要对作为target的图像patch做归一化。2、在计算损失函数时,只对mask patch的部分计算损失函数。

    def forward_loss(self, imgs, pred, mask):
        """compute loss
        Args:
            imgs: (N, 3, H, W)
            pred: (N, L, p*p*3)
            mask: (N, L), 0 is keep, 1 is remove,
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target)**2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

精度复现

参考https://github.com/facebookresearch/mae,我们在单机八卡V100的配置下,对ViT-base和ViT-large的在ImageNet1K上fintune的精度进行了复现。结果如下表所示。

Algorithm ImageNet1K Top-1(%) ImageNet1K Top-1(%) config
vit-b 400 epoch 83.13 mae_vit_base_patch16_8xb64_100e_lrdecay075_fintune
vit-b 1600 epoch 83.55 mae_vit_base_patch16_8xb64_100e_lrdecay065_fintune
vit-l 1600 epoch 85.70 mae_vit_large_patch16_8xb16_50e_lrdecay075_fintune

下面分享一下在复现过程中遇到的一些问题和调参,如有问题请指出。

  1. 在fintune时,MAE的实现使用了mixup+cutmix的数据增广方式,若仅使用mixup精度会下降。
  2. 在fintune时,MAE中使用了所有token特征求平均的方式作为分类head的输入,而cls token作为输入时精度会有下降。
  3. 在预训练过程中,确保使用了足够大的weight_decay(如官方设为0.05),否则在下游任务fintune时,很容易出现梯度爆炸的问题。而在下游分类任务fintune时,设置一个较小的weight,精度会有一些提升。(PS 在复现vit-l时,在pretrain时设置weight_decay 0.01,在fintune时会出现梯度爆炸)

下表展示了vit-b模型的复现过程上述过程的精度提升

parameter setting ImageNet1K Top-1(%)
vit-b 1600 epoch(mixup,cls token) 83.21
vit-b 1600 epoch(mixup+cutmix,cls token) 83.36
vit-b 1600 epoch(mixup+cutmix,global_pool) 83.55

我们在开源框架EasyCV中复现了MAE算法。详细参数配置和实验日志参考github上的自监督modelzoo

Tutorial

接下来,我们将通过一个实际的例子介绍如何基于EasyCV进行MAE算法的预训练和微调,也可以在该链接查看详细步骤。

一、安装依赖包

如果是在本地开发环境运行,可以参考该链接安装环境。若使用PAI-DSW进行实验则无需安装相关依赖,在PAI-DSW docker中已内置相关环境。

二、数据准备

自监督训练只需要提供无标注图片即可进行, 你可以下载ImageNet数据,或者使用你自己的图片数据。需要提供一个包含若干图片的文件夹路径p,以及一个文件列表,文件列表中是每个图片相对图片目录p的路径。

图片文件夹结构示例如下, 文件夹路径为./images

images/
├── 0001.jpg
├── 0002.jpg
├── 0003.jpg
|...
└── 9999.jpg

文件列表内容如下:

0001.jpg
0002.jpg
0003.jpg
...
9999.jpg

为了快速走通流程,我们也提供了一个小的示例数据集,执行如下命令下载解压:

wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/imagenet_raw_demo/imagenet_raw_demo.tar.gz
tar -zxf imagenet_raw_demo.tar.gz
mv imagenet_raw_demo  imagenet_raw

三、模型预训练

以vit-base为示例。在EasyCV中,使用配置文件的形式来实现对模型参数、数据输入及增广方式、训练策略的配置,仅通过修改配置文件中的参数设置,就可以完成实验配置进行训练。可以直接下载示例配置文件。

rm -rf mae_vit_base_patch16_8xb64_1600e.py
wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/doc/easycv/configs/selfsup/mae/mae_vit_base_patch16_8xb64_1600e.py

查看easycv安装位置

# 查看easycv安装位置
import easycv
print(easycv.__file__)

执行训练命令

python -m torch.distributed.launch --nproc_per_node=1 --master_port=29930 \
/home/pai/lib/python3.6/site-packages/easycv/tools/train.py mae_vit_base_patch16_8xb64_1600e.py --work_dir work_dir/selfsup/jpg/mae --launcher pytorch

四、模型微调

1、对上一步得到的预训练模型的字段进行修改,以便用于fintune任务。

import torch 
weight_path = 'work_dir/selfsup/jpg/mae/epoch_5.pth'
state_dict = torch.load(weight_path)['state_dict']
state_dict_out = {}
for key in state_dict:
    state_dict_out[key.replace('encoder.','')] = state_dict[key]
torch.save(state_dict_out,weight_path)

2、下载分类任务示例配置文件

rm -rf mae_vit_base_patch16_8xb64_100e_lrdecay065_fintune.py
wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/doc/easycv/configs/selfsup/mae/mae_vit_base_patch16_8xb64_100e_lrdecay065_fintune.py

3、执行训练命令

python -m torch.distributed.launch --nproc_per_node=1 --master_port=29930 \
/home/pai/lib/python3.6/site-packages/easycv/tools/train.py mae_vit_base_patch16_8xb64_100e_lrdecay065_fintune.py --work_dir work_dir/selfsup/jpg/mae --launcher pytorch

END

后续EasyCV会就SOTA论文复现进行系列的工作介绍,欢迎大家关注和使用,欢迎大家各种维度的反馈和改进建议以及技术讨论,同时我们十分欢迎和期待对开源社区建设感兴趣的同行一起参与共建。

项目开源地址:https://github.com/alibaba/EasyCV
钉钉答疑交流群:41783266

相关实践学习
使用PAI-EAS一键部署ChatGLM及LangChain应用
本场景中主要介绍如何使用模型在线服务(PAI-EAS)部署ChatGLM的AI-Web应用以及启动WebUI进行模型推理,并通过LangChain集成自己的业务数据。
机器学习概览及常见算法
机器学习(Machine Learning, ML)是人工智能的核心,专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能,它是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。 本课程将带你入门机器学习,掌握机器学习的概念和常用的算法。
相关文章
|
机器学习/深度学习 传感器 运维
综述:弱监督下的异常检测算法
# 一、前言 文章标题是: Weakly Supervised Anomaly Detection: A Survey 这是一篇针对“弱监督”异常检测的综述。 其中弱监督异常检测 简称为 WSAD - 论文链接:https://arxiv.org/abs/2302.04549 - 代码链接:https://github.com/yzhao062/wsad # 二、问题 针对异常检测问题,其
432 1
综述:弱监督下的异常检测算法
|
5月前
|
算法 数据挖掘 知识图谱
LINE算法复现 图表示学习 基于line 算法的节点分类 聚类显示 完整代码+数据
LINE算法复现 图表示学习 基于line 算法的节点分类 聚类显示 完整代码+数据
20 0
|
5月前
|
算法 数据挖掘
Struc2vec 算法复现 基于 struc2vec 节点分类聚类实战 完整代码+数据
Struc2vec 算法复现 基于 struc2vec 节点分类聚类实战 完整代码+数据
28 0
|
5月前
|
算法
Deepwalk算法复现: 基于 deepwalk的网络节点分类 聚类分析 完整代码+数据
Deepwalk算法复现: 基于 deepwalk的网络节点分类 聚类分析 完整代码+数据
44 0
|
7月前
|
传感器 机器学习/深度学习 算法
【WSN】移动传感器网络动态覆盖的分布式防拥塞算法matlab复现
【WSN】移动传感器网络动态覆盖的分布式防拥塞算法matlab复现
|
8月前
|
算法 调度
(文章复现)基于灰狼算法(GWO)的交直流混合微网经济调度matlab代码
参考文献: [1]高瑜,黄森,陈刘鑫等.基于改进灰狼算法的并网交流微电网经济优化调度[J].科学技术与工程, 2020,20(28):11605-11611. [2]邓长征,冯朕,邱立等.基于混沌灰狼算法的交直流混合微网经济调度[J].电测与仪表, 2020, 57(04):99-107.
|
传感器 机器学习/深度学习 人工智能
史上最全综述 | 3D目标检测算法汇总!(单目/双目/LiDAR/多模态/时序/半弱自监督)(下)
近年来,自动驾驶因其减轻驾驶员负担、提高行车安全的潜力而受到越来越多的关注。在现代自动驾驶系统中,感知系统是不可或缺的组成部分,旨在准确估计周围环境的状态,并为预测和规划提供可靠的观察结果。3D目标检测可以智能地预测自动驾驶车辆附近关键3D目标的位置、大小和类别,是感知系统的重要组成部分。本文回顾了应用于自动驾驶领域的3D目标检测的进展。
史上最全综述 | 3D目标检测算法汇总!(单目/双目/LiDAR/多模态/时序/半弱自监督)(下)
|
8月前
|
自然语言处理 算法 网络架构
YAKE!无监督关键字抽取算法解读
RAKE遵循由启发式方法支持的无监督方法,该方法可以在短时间内轻松扩展到不同的集合、域和语言。
171 0
|
9月前
|
算法 调度 Python
改进的多目标差分进化算法在电力系统环境经济调度中的应用(Python代码实现)【电气期刊论文复现】
改进的多目标差分进化算法在电力系统环境经济调度中的应用(Python代码实现)【电气期刊论文复现】
|
9月前
|
算法
(文章复现)《基于改进教与学算法的配电网无功优化》
在解决配电网无功优化问题中,智能启发式算法得到了广泛应用。采用了教与学优化算法求解含分布式电源的配电网无功优化问题。现将精英策略引入教与学算法,改进了该算法的搜索能力,提高了求解的稳定性。以有功网损最小为目标建立了无功优化模型,并基于改进的IEEE 33母线配电网系统进行仿真计算,结果验证了基于精英策略改进的教与学算法具有更强的收敛性和鲁棒性,能获得更好的优化结果。

热门文章

最新文章