一张图的一百种 “活” 法 | MMClassification 数据增强介绍第二弹

简介: 既然数据增强手段能够提高模型的泛化能力,那么我们自然希望通过一系列数据增强的组合获得最优的泛化效果,从而衍生出了一系列组合增强手段,这里我们介绍其中最著名也最常用的两个手段,AutoAugment 和 RandAugment。

上一篇中,我们介绍了包括随机翻转、裁剪等在内的一系列基础数据增强手段

640.jpg

而在本篇中,我们将会介绍一些较为进阶的组合增强手段和图像混合增强

640.png


1. 组合数据增强



既然数据增强手段能够提高模型的泛化能力,那么我们自然希望通过一系列数据增强的组合获得最优的泛化效果,从而衍生出了一系列组合增强手段,这里我们介绍其中最著名也最常用的两个手段,AutoAugment 和 RandAugment。


AutoAugment


原论文:AutoAugment: Learning Augmentation Policies from Data

链接:

https://arxiv.org/abs/1805.09501


在 AutoAugment 提出以前,深度学习的研究重点往往是网络结构的优化,连续多年的主流数据增强手段还是随机裁剪缩放(RandomResizedCrop)、随机翻转(RandomFlip),以及渐渐淡出视野的随机光照变换(Lighting)。毕竟数据变换的组合、参数多种多样,而且还会和具体的数据集有关,选择一组合适的数据增强组合的确是一件苦力活。


AutoAugment 提出了一个具有较强可操作性的数据增强组合搜索方法。


1.使用一个 RNN 网络对数据增强空间进行采样,也就是选择一组数据增强策略和相应的参数。每组方法包含 5 个子策略组,每个子策略组包含连续的两个数据增强方法。


2. 使用这一组数据增强策略训练某一个固定结构的神经网络,进而得到其对训练精度的增益。


3. 将该增益作为强化学习的 reward,从而使用 Proximal Policy Optimization(PPO)强化学习算法对用于采样的 RNN 神经网络进行优化。


4. 学习完成后,选择效果最好的 5 组策略,共计 25 个子策略组,综合而成最后的数据增强手段。


在实际使用时,对每张图片,都会随机应用这 25 个子策略组中的一组子策略进行处理。

640.png

那么问题来了,这些子策略应该如何获得呢?MMClassification 中提供了针对 ImageNet 数据集的 AutoAugment 策略;对于其他数据集,可以参照原论文训练学习新的策略,也可以直接使用 ImageNet 数据集的策略,后文中也会提到其他的替代方案。


如果在使用 MMClassification 进行训练时需要使用 ImageNet 的 AutoAugment 策略,可在配置文件中使用以下写法:

# 该基配置文件位于 configs/_base_/datasets/pipelines
# 假设本配置文件位于 configs/_base_/datasets,可使用以下相对路径
_base_ = ['./pipelines/auto_aug.py']
train_pipeline = [
    ...
    # 使用 {{_base_.xxx}} 的写法,可以调用基配置文件中的变量
    dict(type='AutoAugment', policies={{_base_.policy_imagenet}}),
    ...
]

这里我们提供了一个小 demo 脚本来演示效果:

import mmcv
from mmcls.datasets import PIPELINES
# 示例策略,包含两个子策略组,每个子策略组包含两个数据增强方法
# 为了便于演示,我们设所有的概率都为 1.0
demo_policies = [
    [
        dict(type='Posterize', bits=4, prob=1.),  # 降低图片位数
        dict(type='Rotate', angle=30., prob=1.)   # 旋转
    ],
    [
        dict(type='Solarize', thr=256 / 9 * 4, prob=1.),  # 翻转部分暗色像素
        dict(type='AutoContrast', prob=1.)                # 自动调整对比度
    ],
]
# 数据增强配置,利用 Registry 机制创建数据增强对象
aug_cfg = dict(
    type='AutoAugment',
    policies=demo_policies,
    hparams=dict(pad_val=0),  # 设定一些所有子策略共用的参数,如填充值(pad_val)
)
aug = PIPELINES.build(aug_cfg)
img = mmcv.imread("./kittens.jpg")
# 为了便于信息在预处理函数之间传递,数据增强项的输入和输出都是字典
img_info = {'img': img}
img_aug = aug(img_info)['img']
mmcv.imshow(img_aug)

关于配置选项更具体的说明,可以阅读 MMClassficiation 的 官方文档。


官方文档链接:

https://mmclassification.readthedocs.io/zh_CN/latest/api.html#mmcls.datasets.pipelines.AutoAugment


RandAugment


原论文:RandAugment: Practical automated data augmentation with a reduced search space

论文链接:

https://arxiv.org/abs/1909.13719


AutoAugment 将强化学习引入到了数据增强策略的设定中,但是这也意味着针对一个新的数据集学习一套新的数据增强策略会十分复杂且消耗计算资源。


另一方面,为了节约计算资源,大家一般会选择一个训练集子集或者一个小模型来进行数据增强策略的学习。但研究表明,将小数据集 / 小模型上学习到的增强策略应用于完整的任务往往无法获得最优的效果。


为了解决这些问题,研究人员们提出了 RandAugment 这一方法。该方法尽可能地缩减了需要学习的参数, AutoAugment 的每组策略包含5个子策略组,共计10个策略,每个策略又包含幅度和概率两个参数,而  RandAugment 将所有的策略幅度统一为一个值,概率统一设定为平均,因而可以调整的参数只有幅度和每次应用的增强项数这两个变量。


这样一来,不需要应用复杂的强化学习,只需要简单的网格搜索就可以获得针对特定数据集和网络最优的数据增强策略。


具体而言, RandAugment 设置了一个包含各种数据增强变换的集合,对每张图片,随机应用 K 个数据增强变换,每个变换的幅度都是预先设定的幅度(或者在预设幅度的基础上添加一个随机浮动)。

640.png

在 MMClassification 中,我们同样提供了这个数据增强变换集合,使用时可在配置文件中使用以下写法:

# 该基配置文件位于 configs/_base_/datasets/pipelines
# 假设本配置文件位于 configs/_base_/datasets,可使用以下相对路径
_base_ = ['./pipelines/rand_aug.py']
train_pipeline = [
    ...
    # 使用 {{_base_.xxx}} 的写法,可以调用基配置文件中的变量
    dict(
        type='RandAugment',
        policies={{_base_.rand_increasing_policies}},
        num_policies=2,     # 每次随机应用的数据增强变换数
        total_level=10,     # 最大变换幅度 10
        magnitude_level=9,  # 每个数据增强变换的幅度
        magnitude_std=0.5,  # 每次变换时,幅度随机浮动的方差
        ),
    ...
]

同样的,这里我们也提供了一个小 demo 脚本来演示效果:

import mmcv
from mmcls.datasets import PIPELINES
# 示例策略,包含三个数据增强方法
demo_policies = [
    dict(type='Invert'),
    dict(type='Rotate', magnitude_key='angle', magnitude_range=(0, 30)),
    dict(type='Shear',
         magnitude_key='magnitude',
         magnitude_range=(0, 0.3),
         direction='horizontal'),
]
# 数据增强配置,利用 Registry 机制创建数据增强对象
aug_cfg = dict(
    type='RandAugment',
    policies=demo_policies,
    num_policies=2,     # 每次随机应用的数据增强变换数
    total_level=10,     # 最大变换幅度 10
    magnitude_level=9,  # 每个数据增强变换的幅度
    magnitude_std=0.5,  # 每次变换时,幅度随机浮动的方差
    hparams=dict(pad_val=0),  # 设定一些所有子策略共用的参数,如填充值(pad_val)
)
aug = PIPELINES.build(aug_cfg)
img = mmcv.imread("./kittens.jpg")
# 为了便于信息在预处理函数之间传递,数据增强项的输入和输出都是字典
img_info = {'img': img}
img_aug = aug(img_info)['img']
mmcv.imshow(img_aug)


2. 图像混合增强



以上,我们介绍的数据增强方法都是针对单张图片进行的变换,而在最近的研究中,使用复数张图片进行混合的增强手段也有着广泛的应用。


这里我们介绍最常用的两种方法,Mixup 和 CutMix。


Mixup


原论文:mixup: Beyond Empirical Risk Minimization

论文链接:

https://arxiv.org/abs/1710.09412


在监督学习任务中,我们通常可以假设样本为随机变量 X,标签为随机变量 Y,我们希望做的是根据 X-Y 联合分布,优化出一个最佳的 f(x)。但现实中,我们根本无法获得 X-Y 的联合分布,具体到分类任务,也就是我们无法穷尽像素的组合与标签之间的关系,因而只能用一系列样本来拟合这一联合分布。


Mixup 这一增强手段的目的,也即在于优化 “用样本拟合联合分布” 这一步。它采用了一个十分简单的操,将两张图片直接进行线性组合,对应的,标签也进行线性组合。其中,叠加在背景图片上的图片强度按照 Beta(a,a)分布随机采样获得,a 为超参数。

640.png


CutMix


原论文:CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features

链接:

https://arxiv.org/abs/1905.04899?context=cs.CV


CutMix 和 Mixup 的思路一致,其目的都是为了优化“用样本拟合联合分布”的方法。但 CutMix 的作者认为,简单地将两张图片叠加是不自然的,对于分类任务虽然有一定的效果提升,但对检测等下游任务,都会导致性能的下降。


CutMix 与 Mixup 的不同之处在于,CutMix 通过裁取一张图片中的部分区域,以类似拼贴画的方式贴到另一张图片上,来实现图片的混合。其中,裁切图片的面积比例按照 Beta(a,a)分布随机采样获得,a 为超参数。

640.png

在 MMClassification 中的应用


由于图像混合增强不同于其他的对单张图片作用的数据增强手段,其通常需要两张图片,因而使用方法也不同于其他增强方式在 pipeline 中配置的方式,而是在 model 的 train_cfg 中配置,如下:

model = dict(
    ...
    # alpha 为 Mixup 和 CutMix 中的超参数,与图片随机叠加的强度或随机裁剪的面积相关
    # prob 表示对一个 batch 的图像,采用对应增强方法的概率,总和不能超过 1.0
    train_cfg=dict(augments=[
        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
    ]))

同样地,我们提供了一份 demo 脚本来展示效果(由于混合对象是随机选取的,可能出现图像自混合,从而没有效果的现象,可重复尝试):

import torch
import numpy as np
import mmcv
from mmcls.models.utils import Augments
img1 = mmcv.imresize(mmcv.imread("./kittens.jpg"), (224, 224))
img2 = mmcv.imresize(mmcv.imread("./house.jpg"), (224, 224))
imgs = [torch.from_numpy(img.transpose(2, 0, 1)) for img in (img1, img2)]
img_batch = torch.stack(imgs)
labels = torch.tensor([0, 1])
# ---- Mixup ----
mixup_cfg = dict(type='BatchMixup', alpha=0.8, num_classes=2, prob=1.0)
mixup = Augments([mixup_cfg])
img_batch_aug, label_batch_aug = mixup(img_batch, labels)
img1_aug = img_batch_aug[0].numpy().transpose(1, 2, 0).astype('uint8')
img2_aug = img_batch_aug[1].numpy().transpose(1, 2, 0).astype('uint8')
mmcv.imshow(img1_aug, win_name=f'mixup-{label_batch_aug[0]}')
mmcv.imshow(img2_aug, win_name=f'mixup-{label_batch_aug[1]}')
# ---- CutMix ----
cutmix_cfg = dict(type='BatchCutMix', alpha=1.0, num_classes=2, prob=1.0)
cutmix = Augments([cutmix_cfg])
img_batch_aug, label_batch_aug = cutmix(img_batch, labels)
img1_aug = img_batch_aug[0].numpy().transpose(1, 2, 0).astype('uint8')
img2_aug = img_batch_aug[1].numpy().transpose(1, 2, 0).astype('uint8')
mmcv.imshow(img1_aug, win_name=f'cutmix-{label_batch_aug[0]}')
mmcv.imshow(img2_aug, win_name=f'cutmix-{label_batch_aug[1]}')


文章来源:公众号【OpenMMLab】

2021-11-19 21:31


目录
相关文章
|
存储
将PC端的apk文件通过微信文件分享到手机,后缀名有.1
将PC端的apk文件通过微信文件分享到手机,后缀名有.1
688 0
|
SQL Oracle 关系型数据库
案例分析:你造吗?有个ORA-60死锁的解决方案
这段时间应用一直被一个诡异的 ORA-00060 的错误所困扰,众所周知,造成 ORA-00060 的原因是由于应用逻辑,而非 Oracle 数据库自己,之所以说诡异(“诡异”可能不准确,只能说这种场景,以前碰见的少,并未刻意关注),是因为这次不是常见的,由于读取数据顺序有交叉,导致ORA-0006.
2818 0
|
9月前
|
编解码 人工智能 API
飞桨x昇腾生态适配方案:12_动态OM推理
本文介绍了基于Ascend AI平台的OM模型动态推理方法,包括动态BatchSize、动态分辨率、动态维度及动态Shape四种场景,支持固定模式与自动设置模式。通过`ais_bench`工具实现推理,提供示例命令及输出结果说明,并解决常见问题(如环境变量未设置、输入与模型不匹配等)。此外,还提供了API推理指南及参考链接,帮助用户深入了解ONNX离线推理流程、性能优化案例及工具使用方法。
969 0
|
7月前
|
机器学习/深度学习 人工智能 机器人
Meta AI Research:虚拟/可穿戴/机器人三位一体的AI进化路径
本文阐述了我们对具身AI代理的研究——这些代理以视觉、虚拟或物理形式存在,使其能够与用户及环境互动。这些代理包括虚拟化身、可穿戴设备和机器人,旨在感知、学习并在其周围环境中采取行动。与非具身代理相比,这种特性使它们更接近人类的学习与环境交互方式。我们认为,世界模型的构建是具身AI代理推理与规划的核心,这使代理能够理解并预测环境、解析用户意图及社会背景,从而增强其自主完成复杂任务的能力。世界建模涵盖多模态感知的整合、通过推理进行行动规划与控制,以及记忆机制,以形成对物理世界的全面认知。除物理世界外,我们还提出需学习用户的心理世界模型,以优化人机协作。
661 3
|
6月前
|
传感器 运维 监控
AR眼镜在工业运维的场景应用和方案说明
AR眼镜通过虚实融合技术,革新工业运维模式。从设备巡检、故障维修到员工培训,AR实现远程协作、实时数据叠加与沉浸式教学,大幅提升效率与准确性,推动智能工厂发展。
|
7月前
|
安全 应用服务中间件 网络安全
在Linux环境部署Flask应用并启用SSL/TLS安全协议
至此,你的Flask应用应该能够通过安全的HTTPS协议提供服务了。记得定期更新SSL证书,Certbot可以帮你自动更新证书。可以设定cronjob以实现这一点。
535 10
|
10月前
|
人工智能 运维 安全
函数计算支持热门 MCP Server 一键部署
MCP(Model Context Protocol)自2024年发布以来,逐渐成为AI开发领域的实施标准。OpenAI宣布其Agent SDK支持MCP协议,进一步推动了其普及。然而,本地部署的MCP Server因效率低、扩展性差等问题,难以满足复杂生产需求。云上托管成为趋势,函数计算(FC)作为Serverless算力代表,提供一键托管开源MCP Server的能力,解决传统托管痛点,如成本高、弹性差、扩展复杂等。通过CAP平台,用户可快速部署多种热门MCP Server,体验高效灵活的AI应用开发与交互方式。
3869 10
|
安全 搜索推荐 网络安全
外贸网站应该如何搭建?
建立优质的外贸网站需要进行需求分析、域名选择、SSL证书部署、建站产品选择、便利性和响应式设计以及高质量内容SEO优化。选择合适的模板、部署SSL证书和高质量内容是关键。建站门槛低,节省成本,同时提升用户体验和搜索引擎可见性。
376 2
|
安全 Go 调度
解密Go语言并发模型:CSP与goroutine的魔法
在本文中,我们将深入探讨Go语言的并发模型,特别是CSP(Communicating Sequential Processes)理论及其在Go中的实现——goroutine。我们将分析CSP如何为并发编程提供了一种清晰、简洁的方法,并通过goroutine展示Go语言在处理高并发场景下的独特优势。
|
Java 关系型数据库 MySQL
新闻发布|基于JavaWeb实现新闻发布管理系统+论文+PPT(一)
新闻发布|基于JavaWeb实现新闻发布管理系统+论文+PPT
536 0

热门文章

最新文章