MMClassificiation|实现数据增强的 N 种方法

简介: 众所周知,即使是目前最先进的神经网络模型,其本质上也是在利用一系列线性和非线性的函数去拟合目标输出。既然是拟合,当然越多的样本就能获得越准确的结果,这也是为什么现在训练神经网络所使用的数据规模越来越大的原因。

众所周知,即使是目前最先进的神经网络模型,其本质上也是在利用一系列线性和非线性的函数去拟合目标输出。


既然是拟合,当然越多的样本就能获得越准确的结果,这也是为什么现在训练神经网络所使用的数据规模越来越大的原因。

640.png

然而,在实际使用中,我们往往可能只有几千甚至几百份数据。面对神经网络数以 M 计的参数,很容易陷入过拟合的陷阱。


因为神经网络的收敛需要一个较长的训练过程,而这个过程中网络遇到的反反复复都是训练集的那几张图片,硬背都背下来了,自然很难学到什么能够泛化的特征。


一个自然的想法是,能不能用一张图片去生成一系列图片,从而成百上千倍地扩充我们的数据集?而这,也正是数据增强的目的之一。


如果说过拟合的问题好歹还能通过增加训练样本来缓解,那么下面这个问题就不得不依靠数据增强了——神经网络的“取巧”。


640.png

神经网络是没有常识的,因此它永远只会用最“方便”的方式区分两个类别。


假设我们要训练一个区分苹果和橘子的神经网络,但手上的数据只有红苹果和青橘子,那无论我们拍摄多少张照片,神经网络也只会简单地认为红色的就是苹果,青色的就是橘子。


这在实际使用中经常出现,拍摄的灯光、拍摄的角度等等,任何一个不起眼的区分点,都会被神经网络当做分类的依据。


接下来我们将列举当前分类方向研究中最常用的一系列数据增强手段和效果,并会给出在 MMClassification 中应用这些方法的具体例子,和大家分享。


本文内容

一个常见的误区

数据增强的6个常用方法

  随机翻转

  随机裁剪

  随机比例裁剪并缩放

  色彩抖动

  随机灰度化

  随机光照变换


1. 一个常见的误区



在介绍数据增强方法之前,希望能澄清一个常见的误区——一些人会认为,既然有这么多数据增强的方法,那么我一口气全堆到一起,是不是就能获得最好的增强效果?


答案是否定的,数据增强的目标并不是无脑地堆数据,而是尽可能地去覆盖原始数据无法覆盖不到,但现实生活中会出现的情况。


举个栗子,我们现在要训练一个用以区分道路上汽车种类的神经网络,那么图片的垂直翻转很大程度上就不是一个好的数据增强方法,毕竟现实中不太可能遇到汽车四轮朝上的情况,除了在盗梦空间。

640.png

再比如,我们希望训练一个区分水果是否成熟的神经网络,那么一些颜色相关的数据增强方法可能反而对训练结果有害。


这两个例子在提醒我们,有必要对数据增强方法有一个清晰的了解,然后针对自己的任务,选择合适的数据增强方法,才能充分发挥数据增强的作用。

2. 数据增强的6个常用方法



随机翻转

RandomFlip



随机翻转是一个非常常用的数据增强方法,包括水平和垂直翻转。其中,水平翻转是最常用的,但根据实际目标的不同,垂直翻转也可以使用。

640.png

在 MMClassificiation 中,大部分数据增强方法都可以通过修改 config 中的 pipeline 配置来实现。


这里我们提供了一份 python 代码,用来展示如上图所示的数据增强效果:

import mmcv
from mmcls.datasets import PIPELINES
# 数据增强配置,利用 Registry 机制创建数据增强对象
aug_cfg = dict(
    type='RandomFlip',
    flip_prob=0.5,           # 按 50% 的概率随机翻转图像
    direction='horizontal',  # 翻转方向为水平翻转
)
aug = PIPELINES.build(aug_cfg)
img = mmcv.imread("./kittens.jpg")
# 为了便于信息在预处理函数之间传递,数据增强项的输入和输出都是字典
img_info = {'img': img}
img_aug = aug(img_info)['img']
mmcv.imshow(img_aug)

关于配置选项更具体的说明,可以阅读 MMClassficiation 的官方文档。


官方文档链接 :

https://mmclassification.readthedocs.io/zh_CN/latest/api.html#mmcls.datasets.pipelines.RandomFlip


随机裁剪

RandomCrop



在图片的随机位置,按照指定的大小进行裁剪。


这种数据增强的方式能够在保留图像比例的基础上,移动图片上各区域在图片上的位置。

640.png

在 MMClassification 中,可使用以下配置:

# 此处只提供 cfg 选项,只需替换 RandomFlip 示例中对应部分,即可预览效果
aug_cfg = dict(
    type='RandomCrop',
    size=(384, 384),           # 裁剪大小
    padding = None,            # 边缘填充宽度(None 为不填充)
    pad_if_needed=True,        # 如果图片过小,是否自动填充边缘
    pad_val=(128, 128, 128),   # 边缘填充像素
    padding_mode='constant',   # 边缘填充模式
)


随机比例裁剪并缩放

RandomResizedCrop



这一方法目前几乎是 ImageNet 等通用图像数据集在进行分类网络训练时的标准增强手段


相较于 RandomCrop 死板地裁剪下固定尺寸的图片,RandomResizedCrop 会在一定的范围内,在随机位置按照随机比例裁剪图像,之后再缩放至统一的大小。


因此,图像会在比例上存在一定程度的失真。但这对分类来说不一定是件坏事,毕竟你并不会把一个稍扁一点的猫认成狗,而网络也能够通过这种增强学到更加接近本质的特征。


另外,因为是按比例的裁剪,这种增强手段也就对不同分辨率的图片输入更加友好。


640.png

在 MMClassification 中,可使用以下配置:


# 此处只提供 cfg 选项,只需替换 RandomFlip 示例中对应部分,即可预览效果
aug_cfg = dict(
    type='RandomResizedCrop',
    size=(384, 384),            # 目标大小
    scale=(0.08, 1.0),          # 裁剪图片面积占比限制(不得小于原始面积的 8%)
    ratio=(3. / 4., 4. / 3.),   # 裁剪图片长宽比例限制,防止过度失真
    max_attempts=10,            # 当长宽比和面积限制无法同时满足时,最大重试次数
    interpolation='bilinear',   # 图像缩放算法
    backend='cv2',              # 缩放后端,有时 'cv2'(OpenCV) 和 'pillow' 有微小差别
)

色彩抖动(ColorJitter)


上面我们介绍了两种基于裁剪的数据增强方法,接下来我们介绍一些对图像的色彩进行数据增强的方法


其中最常用的莫过于 ColorJitter,这种方法会在一定范围内,对图像的亮度(Brightness)、对比度(Contrast)、饱和度(Saturation)和色相(Hue)进行随机变换,从而模拟真实拍摄中不同灯光环境等条件的变化。

640.png

在 MMClassification 中,可使用以下配置:

# 此处只提供 cfg 选项,只需替换 RandomFlip 示例中对应部分,即可预览效果
aug_cfg = dict(
    type='ColorJitter',
    brightness=0.5,    # 亮度变化范围(0.5 ~ 1.5)
    contrast=0.5,      # 对比度变化范围(0.5 ~ 1.5)
    saturation=0.5,    # 饱和度变化范围(0.5 ~ 1.5)
    # 色相变换应用较少,目前 MMClassification 暂不支持 Hue 的增强
)

随机灰度化(RandomGrayscale)


按照一定概率,将图片转变为灰度图。这种增强方法消除了颜色的影响,在特定场景有所应用。

640.png

在 MMClassification 中,可使用以下配置:

# 此处只提供 cfg 选项,只需替换 RandomFlip 示例中对应部分,即可预览效果
aug_cfg = dict(
    type='RandomGrayscale',
    gray_prob=0.5,    # 按 50% 的概率随机灰度化图像
)

随机光照变换(Lighting)


Lighting 是在《ImageNet Classification with Deep Convolutional Neural Networks》一文中提出的一种针对图片光照的数据增强方法。


在这种方法中,我们首先在训练数据集中对所有图像的像素进行 PCA(主成分分析),从而获得 RGB 空间中的特征值和特征向量。那么这个特征向量代表了什么呢?此处,论文作者认为它代表了光照强度对图片像素的影响。毕竟虽然图像内容各种各样,但不管哪张图片的哪个位置,都不可避免地受到光照条件的影响。


既然特征向量代表了光照强度的影响,那么我们只要沿着特征向量的方向对图片的像素值做一些随机的加减,就能模拟不640.png同光照的图像了。


在 MMClassification 中,可使用以下配置。需要注意的是其中特征值和特征向量的设置,如果你的任务是在通用场景下的分类,那么可以直接沿用 ImageNet 的值;


而如果你的任务是在特殊光照环境下的,那么则需要采集不同光照强度下的图像,在自己的数据集上进行 PCA 来替代这里的设置。

import mmcv
from mmcls.datasets import PIPELINES
aug_cfg = dict(
    type='Lighting',
    eigval=[55.4625, 4.7940, 1.1475],      # 在 ImageNet 训练集 PCA 获得的特征值
    eigvec=[[-0.5675, 0.7192, 0.4009],     # 在 ImageNet 训练集 PCA 获得的特征向量
            [-0.5808, -0.0045, -0.8140],
            [-0.5836, -0.6948, 0.4203]],
    alphastd=2.0,  # 随机变换幅度,为了展示效果,这里设置较大,通常设置为 0.1
    to_rgb=True,   # 是否将图像转换为 RGB,mmcv 读取图像为 BGR 格式,为了与特征向量对应,此处转为 RGB
)
aug = PIPELINES.build(aug_cfg)
img = mmcv.imread("./kittens.jpg")
img_info = {'img': img}
img_aug = aug(img_info)['img']
# Lighting 变换得到的图像为 float32 类型,且超出 0~255 范围,为了可视化,此处进行限制
img_aug[img_aug < 0] = 0
img_aug[img_aug > 255] = 255
img_aug = img_aug.astype('uint8')[:, :, ::-1]   # 转回 BGR 格式
mmcv.imshow(img_aug)

以上介绍的数据增强方法只是常用方法的一部分,更多的数据增强方法,如多种方法的随机组合(AutoAugment、RandAugment)、多张图片的混合增强(MixUp、CutMix)等等,后续将为大家详细介绍。


另外,我们之后还会以一个当前流行的数据增强流程为例,详细介绍其使用方法,敬请期待~

文章来源:公众号【OpenMMLab】

2021-10-20 19:27

目录
相关文章
|
4月前
|
机器学习/深度学习 存储 Python
数据增强
【7月更文挑战第29天】
57 15
|
5月前
|
编解码 算法 计算机视觉
YOLOv8数据增强预处理方式详解:包括数据增强的作用,数据增强方式与方法
YOLOv8数据增强预处理方式详解:包括数据增强的作用,数据增强方式与方法
|
5月前
|
jenkins 测试技术 持续交付
利用C++增强框架的可测试性(Testability)
**C++框架可测试性提升策略**:通过模块化设计、依赖注入、使用Mock对象和Stub、编写清晰接口及文档、断言与异常处理、分离测试代码与生产代码、自动化测试,可以有效增强C++框架的可测试性。这些方法有助于确保代码正确性、健壮性,提高可维护性和可扩展性。示例包括使用类和接口实现模块化,通过构造函数进行依赖注入,以及利用Google Test和Google Mock进行断言和模拟测试。
80 1
|
机器学习/深度学习 算法框架/工具 Python
pyton数据增强
pyton数据增强
79 0
|
人工智能 程序员 C#
通过简单原理增强软件可靠性
通过简单原理增强软件可靠性
|
PyTorch 算法框架/工具
语义分割数据增强——图像和标注同步增强
其中常见的数据增强方式包括:旋转、垂直翻转、水平翻转、放缩、剪裁、归一化等。
684 0
|
机器学习/深度学习 人工智能 算法
数据增强方法汇总
数据增强方法汇总
251 0
|
机器学习/深度学习 存储 计算机视觉
【目标检测】常用数据增强从原理到实现
【目标检测】常用数据增强从原理到实现
287 0
|
机器学习/深度学习 存储 编解码
3D检测无痛涨点 | 上下文感知数据增强方法上下文感知数据增强方法CA-Aug助力3D!
3D检测无痛涨点 | 上下文感知数据增强方法上下文感知数据增强方法CA-Aug助力3D!
172 0
|
机器学习/深度学习 算法 测试技术
使用用测试时数据增强(TTA)提高预测结果(上)
使用用测试时数据增强(TTA)提高预测结果
515 0