一张图的一百种 “活” 法 | MMClassification 数据增强介绍第二弹

简介: 既然数据增强手段能够提高模型的泛化能力,那么我们自然希望通过一系列数据增强的组合获得最优的泛化效果,从而衍生出了一系列组合增强手段,这里我们介绍其中最著名也最常用的两个手段,AutoAugment 和 RandAugment。

上一篇中,我们介绍了包括随机翻转、裁剪等在内的一系列基础数据增强手段

640.jpg

而在本篇中,我们将会介绍一些较为进阶的组合增强手段和图像混合增强

640.png


1. 组合数据增强



既然数据增强手段能够提高模型的泛化能力,那么我们自然希望通过一系列数据增强的组合获得最优的泛化效果,从而衍生出了一系列组合增强手段,这里我们介绍其中最著名也最常用的两个手段,AutoAugment 和 RandAugment。


AutoAugment


原论文:AutoAugment: Learning Augmentation Policies from Data

链接:

https://arxiv.org/abs/1805.09501


在 AutoAugment 提出以前,深度学习的研究重点往往是网络结构的优化,连续多年的主流数据增强手段还是随机裁剪缩放(RandomResizedCrop)、随机翻转(RandomFlip),以及渐渐淡出视野的随机光照变换(Lighting)。毕竟数据变换的组合、参数多种多样,而且还会和具体的数据集有关,选择一组合适的数据增强组合的确是一件苦力活。


AutoAugment 提出了一个具有较强可操作性的数据增强组合搜索方法。


1.使用一个 RNN 网络对数据增强空间进行采样,也就是选择一组数据增强策略和相应的参数。每组方法包含 5 个子策略组,每个子策略组包含连续的两个数据增强方法。


2. 使用这一组数据增强策略训练某一个固定结构的神经网络,进而得到其对训练精度的增益。


3. 将该增益作为强化学习的 reward,从而使用 Proximal Policy Optimization(PPO)强化学习算法对用于采样的 RNN 神经网络进行优化。


4. 学习完成后,选择效果最好的 5 组策略,共计 25 个子策略组,综合而成最后的数据增强手段。


在实际使用时,对每张图片,都会随机应用这 25 个子策略组中的一组子策略进行处理。

640.png

那么问题来了,这些子策略应该如何获得呢?MMClassification 中提供了针对 ImageNet 数据集的 AutoAugment 策略;对于其他数据集,可以参照原论文训练学习新的策略,也可以直接使用 ImageNet 数据集的策略,后文中也会提到其他的替代方案。


如果在使用 MMClassification 进行训练时需要使用 ImageNet 的 AutoAugment 策略,可在配置文件中使用以下写法:

# 该基配置文件位于 configs/_base_/datasets/pipelines
# 假设本配置文件位于 configs/_base_/datasets,可使用以下相对路径
_base_ = ['./pipelines/auto_aug.py']
train_pipeline = [
    ...
    # 使用 {{_base_.xxx}} 的写法,可以调用基配置文件中的变量
    dict(type='AutoAugment', policies={{_base_.policy_imagenet}}),
    ...
]

这里我们提供了一个小 demo 脚本来演示效果:

import mmcv
from mmcls.datasets import PIPELINES
# 示例策略,包含两个子策略组,每个子策略组包含两个数据增强方法
# 为了便于演示,我们设所有的概率都为 1.0
demo_policies = [
    [
        dict(type='Posterize', bits=4, prob=1.),  # 降低图片位数
        dict(type='Rotate', angle=30., prob=1.)   # 旋转
    ],
    [
        dict(type='Solarize', thr=256 / 9 * 4, prob=1.),  # 翻转部分暗色像素
        dict(type='AutoContrast', prob=1.)                # 自动调整对比度
    ],
]
# 数据增强配置,利用 Registry 机制创建数据增强对象
aug_cfg = dict(
    type='AutoAugment',
    policies=demo_policies,
    hparams=dict(pad_val=0),  # 设定一些所有子策略共用的参数,如填充值(pad_val)
)
aug = PIPELINES.build(aug_cfg)
img = mmcv.imread("./kittens.jpg")
# 为了便于信息在预处理函数之间传递,数据增强项的输入和输出都是字典
img_info = {'img': img}
img_aug = aug(img_info)['img']
mmcv.imshow(img_aug)

关于配置选项更具体的说明,可以阅读 MMClassficiation 的 官方文档。


官方文档链接:

https://mmclassification.readthedocs.io/zh_CN/latest/api.html#mmcls.datasets.pipelines.AutoAugment


RandAugment


原论文:RandAugment: Practical automated data augmentation with a reduced search space

论文链接:

https://arxiv.org/abs/1909.13719


AutoAugment 将强化学习引入到了数据增强策略的设定中,但是这也意味着针对一个新的数据集学习一套新的数据增强策略会十分复杂且消耗计算资源。


另一方面,为了节约计算资源,大家一般会选择一个训练集子集或者一个小模型来进行数据增强策略的学习。但研究表明,将小数据集 / 小模型上学习到的增强策略应用于完整的任务往往无法获得最优的效果。


为了解决这些问题,研究人员们提出了 RandAugment 这一方法。该方法尽可能地缩减了需要学习的参数, AutoAugment 的每组策略包含5个子策略组,共计10个策略,每个策略又包含幅度和概率两个参数,而  RandAugment 将所有的策略幅度统一为一个值,概率统一设定为平均,因而可以调整的参数只有幅度和每次应用的增强项数这两个变量。


这样一来,不需要应用复杂的强化学习,只需要简单的网格搜索就可以获得针对特定数据集和网络最优的数据增强策略。


具体而言, RandAugment 设置了一个包含各种数据增强变换的集合,对每张图片,随机应用 K 个数据增强变换,每个变换的幅度都是预先设定的幅度(或者在预设幅度的基础上添加一个随机浮动)。

640.png

在 MMClassification 中,我们同样提供了这个数据增强变换集合,使用时可在配置文件中使用以下写法:

# 该基配置文件位于 configs/_base_/datasets/pipelines
# 假设本配置文件位于 configs/_base_/datasets,可使用以下相对路径
_base_ = ['./pipelines/rand_aug.py']
train_pipeline = [
    ...
    # 使用 {{_base_.xxx}} 的写法,可以调用基配置文件中的变量
    dict(
        type='RandAugment',
        policies={{_base_.rand_increasing_policies}},
        num_policies=2,     # 每次随机应用的数据增强变换数
        total_level=10,     # 最大变换幅度 10
        magnitude_level=9,  # 每个数据增强变换的幅度
        magnitude_std=0.5,  # 每次变换时,幅度随机浮动的方差
        ),
    ...
]

同样的,这里我们也提供了一个小 demo 脚本来演示效果:

import mmcv
from mmcls.datasets import PIPELINES
# 示例策略,包含三个数据增强方法
demo_policies = [
    dict(type='Invert'),
    dict(type='Rotate', magnitude_key='angle', magnitude_range=(0, 30)),
    dict(type='Shear',
         magnitude_key='magnitude',
         magnitude_range=(0, 0.3),
         direction='horizontal'),
]
# 数据增强配置,利用 Registry 机制创建数据增强对象
aug_cfg = dict(
    type='RandAugment',
    policies=demo_policies,
    num_policies=2,     # 每次随机应用的数据增强变换数
    total_level=10,     # 最大变换幅度 10
    magnitude_level=9,  # 每个数据增强变换的幅度
    magnitude_std=0.5,  # 每次变换时,幅度随机浮动的方差
    hparams=dict(pad_val=0),  # 设定一些所有子策略共用的参数,如填充值(pad_val)
)
aug = PIPELINES.build(aug_cfg)
img = mmcv.imread("./kittens.jpg")
# 为了便于信息在预处理函数之间传递,数据增强项的输入和输出都是字典
img_info = {'img': img}
img_aug = aug(img_info)['img']
mmcv.imshow(img_aug)


2. 图像混合增强



以上,我们介绍的数据增强方法都是针对单张图片进行的变换,而在最近的研究中,使用复数张图片进行混合的增强手段也有着广泛的应用。


这里我们介绍最常用的两种方法,Mixup 和 CutMix。


Mixup


原论文:mixup: Beyond Empirical Risk Minimization

论文链接:

https://arxiv.org/abs/1710.09412


在监督学习任务中,我们通常可以假设样本为随机变量 X,标签为随机变量 Y,我们希望做的是根据 X-Y 联合分布,优化出一个最佳的 f(x)。但现实中,我们根本无法获得 X-Y 的联合分布,具体到分类任务,也就是我们无法穷尽像素的组合与标签之间的关系,因而只能用一系列样本来拟合这一联合分布。


Mixup 这一增强手段的目的,也即在于优化 “用样本拟合联合分布” 这一步。它采用了一个十分简单的操,将两张图片直接进行线性组合,对应的,标签也进行线性组合。其中,叠加在背景图片上的图片强度按照 Beta(a,a)分布随机采样获得,a 为超参数。

640.png


CutMix


原论文:CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features

链接:

https://arxiv.org/abs/1905.04899?context=cs.CV


CutMix 和 Mixup 的思路一致,其目的都是为了优化“用样本拟合联合分布”的方法。但 CutMix 的作者认为,简单地将两张图片叠加是不自然的,对于分类任务虽然有一定的效果提升,但对检测等下游任务,都会导致性能的下降。


CutMix 与 Mixup 的不同之处在于,CutMix 通过裁取一张图片中的部分区域,以类似拼贴画的方式贴到另一张图片上,来实现图片的混合。其中,裁切图片的面积比例按照 Beta(a,a)分布随机采样获得,a 为超参数。

640.png

在 MMClassification 中的应用


由于图像混合增强不同于其他的对单张图片作用的数据增强手段,其通常需要两张图片,因而使用方法也不同于其他增强方式在 pipeline 中配置的方式,而是在 model 的 train_cfg 中配置,如下:

model = dict(
    ...
    # alpha 为 Mixup 和 CutMix 中的超参数,与图片随机叠加的强度或随机裁剪的面积相关
    # prob 表示对一个 batch 的图像,采用对应增强方法的概率,总和不能超过 1.0
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ]))

同样地,我们提供了一份 demo 脚本来展示效果(由于混合对象是随机选取的,可能出现图像自混合,从而没有效果的现象,可重复尝试):

import torch
import numpy as np
import mmcv
from mmcls.models.utils import Augments
img1 = mmcv.imresize(mmcv.imread("./kittens.jpg"), (224, 224))
img2 = mmcv.imresize(mmcv.imread("./house.jpg"), (224, 224))
imgs = [torch.from_numpy(img.transpose(2, 0, 1)) for img in (img1, img2)]
img_batch = torch.stack(imgs)
labels = torch.tensor([0, 1])
# ---- Mixup ----
mixup_cfg = dict(type='BatchMixup', alpha=0.8, num_classes=2, prob=1.0)
mixup = Augments([mixup_cfg])
img_batch_aug, label_batch_aug = mixup(img_batch, labels)
img1_aug = img_batch_aug[0].numpy().transpose(1, 2, 0).astype('uint8')
img2_aug = img_batch_aug[1].numpy().transpose(1, 2, 0).astype('uint8')
mmcv.imshow(img1_aug, win_name=f'mixup-{label_batch_aug[0]}')
mmcv.imshow(img2_aug, win_name=f'mixup-{label_batch_aug[1]}')
# ---- CutMix ----
cutmix_cfg = dict(type='BatchCutMix', alpha=1.0, num_classes=2, prob=1.0)
cutmix = Augments([cutmix_cfg])
img_batch_aug, label_batch_aug = cutmix(img_batch, labels)
img1_aug = img_batch_aug[0].numpy().transpose(1, 2, 0).astype('uint8')
img2_aug = img_batch_aug[1].numpy().transpose(1, 2, 0).astype('uint8')
mmcv.imshow(img1_aug, win_name=f'cutmix-{label_batch_aug[0]}')
mmcv.imshow(img2_aug, win_name=f'cutmix-{label_batch_aug[1]}')


文章来源:公众号【OpenMMLab】

2021-11-19 21:31


目录
相关文章
|
6月前
|
人工智能 搜索推荐
StableIdentity:可插入图像/视频/3D生成,单张图即可变成超人,可直接与ControlNet配合使用
【2月更文挑战第17天】StableIdentity:可插入图像/视频/3D生成,单张图即可变成超人,可直接与ControlNet配合使用
109 2
StableIdentity:可插入图像/视频/3D生成,单张图即可变成超人,可直接与ControlNet配合使用
|
机器学习/深度学习 图计算 图形学
同构图、异构图、属性图、非显式图
同构图(Homogeneous Graph)、异构图(Heterogeneous Graph)、属性图(Property Graph)和非显式图(Graph Constructed from Non-relational Data)。 (1)同构图:
1955 0
同构图、异构图、属性图、非显式图
|
6月前
|
前端开发 计算机视觉
InstantStyle,无需训练,风格保留文生图
InstantStyle 是一个通用框架,它采用两种简单但有效的技术来实现风格和内容与参考图像的有效分离。
|
6月前
|
算法 图形学 UED
Unity Hololens2开发|(八)MRTK3空间操作 BoundsControl(边界控制)
Unity Hololens2开发|(八)MRTK3空间操作 BoundsControl(边界控制)
聊天框(番外篇)—如何实现@功能的整体删除
上一篇文章中,我们已经初步实现了聊天输入框,但其@功能是不完善的,例如无法整体删除、无法获取除用户名以外的数据(假设用户名不是唯一的)。有问题就要想办法解决,在网上百度了一圈后,倒是有一些收获。本文就着重解决@的整体删除以及获取额外数据。
1103 0
聊天框(番外篇)—如何实现@功能的整体删除
|
JavaScript 测试技术 Python
WebUI自动化测试中隐藏的元素如何操作?三种元素等待方式如何理解?
WebUI自动化测试中隐藏的元素如何操作?三种元素等待方式如何理解?
76 0
|
存储 数据可视化 atlas
maftools | 从头开始绘制发表级oncoplot(瀑布图)
maftools | 从头开始绘制发表级oncoplot(瀑布图)
409 0
|
算法
基于自动亮度对比度增强功能的可逆数据隐藏(Matlab代码实现)
基于自动亮度对比度增强功能的可逆数据隐藏(Matlab代码实现)
114 0
|
人工智能 自然语言处理 文字识别
理解指向,说出坐标,Shikra开启多模态大模型参考对话新维度
理解指向,说出坐标,Shikra开启多模态大模型参考对话新维度
201 0
|
机器学习/深度学习
GraphCL:基于数据增强的图对比学习
GraphCL:基于数据增强的图对比学习
666 0
GraphCL:基于数据增强的图对比学习