自动数据增强论文及算法解读(附代码)

简介: 自动数据增强论文及算法解读(附代码)

论文题目


Abstract


数据增强是提高图像分类器精度的有效技术。但是当前的数据增强实现是手工设计的。在本论文中,我们提出了AutoAugment来自动搜索改进数据增强策略。我们设计了一个搜索空间,其中一个策略由许多子策略组成,每个小批量的每个图像随机选择一个子策略。子策略由两个操作组成,每个操作都是图像处理功能,例如平移,旋转或剪切,以及应用这些功能的概率。我们使用搜索算法来找到最佳策略,使得神经网络在目标数据集上产生最高的验证准确度。我们的方法在ImageNet上获得了83.5%的top1准确度,比之前83.1%的记录好0.4%。在CIFAR-10上,我们实现了1.5%的错误率,比之前的记录好了0.6%。扩充策略在数据集之间是可以相互转换的。在ImageNet上学到的策略也能在其他数据集上实现显著的提升。

Introduction


数据扩充是一种通过随机“扩充”数据来增加数据量和多样性的有效技术,是用来教授一个关于数据域不变性的模型,即使模型有良好的鲁棒性和平移不变性。

机器学习和计算机视觉领域的一大重点是设计更好的网络架构。人们很少注意寻找更好的数据增强方法,这种方法包含更多的不变性。例如,在ImageNet上,2012年引入的数据增强方法仍然是标准,只是有一些小的变化。即使对特定数据集进行了增强改进,它们通常也不会有效地转移到其他数据集。例如,训练期间图像的水平翻转是CIFAR-10上的一种有效数据增强方法,但在MNIST上则不是,因为这些数据集中存在不同的对称性。最近,人们提出了自动学习数据扩充的需求,这是一个尚未解决的重要问题。

我们使用搜索算法来寻找数据增强操作的最佳选择和顺序(如水平垂直翻转、平移、颜色归一化等等),这样训练神经网络可以获得最佳的验证精度。我们使用强化学习作为搜索算法,以此来训练和选择最佳的方法。

我们通过大量的实验表明在两种情况下AutoAugment可以获得很好的提升:1)AutoAugment直接应用于感兴趣的数据集,以找到最佳的扩充策略(AutoAugment-direct),2)学到的策略可以迁移到新的数据集(AutoAugment-transfer)。首先,对于直接应用,我们的方法在数据集上实现了最先进的准确性,例如CIFAR-10,减少的CIFAR-10,CIFAR-100,SVHN,减少的SVHN和ImageNet(没有附加数据)。其次,如果直接应用代价很高,迁移增强政策可能是一个不错的选择。我们展示了在一个任务上找到的策略可以在不同的模型和数据集中很好地泛化。例如,ImageNet上的策略可以显著改善各种FGVC数据集。具体结果见下表。

fd454d7cb2ad645efa09a143082157dd.png

AutoAugment: 直接在感兴趣的数据集上搜索最佳扩充策略


image.png

877e4d619e60dc55f499f7ab3302d96a.png

  • 图1  我们使用搜索方法(例如强化学习)来搜索更好的数据扩充策略的框架。控制器RNN从搜索空间预测扩充策略。训练一个具有固定结构的子网络,使其收敛到精度R。奖励R将与策略梯度方法一起使用,以更新控制器,使其能够随着时间的推移生成更好的策略。

搜索空间详细信息:


在我们设计的搜索空间中,一个策略包括5个子策略,每个子策略包括两个有序的图像运算。另外,每个运算与两个超参数相关:1)应用操作的概率,2)操作的幅度大小。

图2是搜索空间中包含5个子策略的样例。第一个子策略指定了ShearX应用,然后反转图像的像素。ShearX的概率是0.9,使用幅值大小是7/10。然后应用概率是0.8的翻转。反转图像像素操作不需要使用幅值大小。这些操作都是按照指定顺序进行的。

788bea141affb58a16ee25f4636ce445.png

图2 不同小批量数据增强结果


如上图所示,该策略有5个子策略。对于一个小批量中的每一幅图像,我们随机均匀地选择一个子策略来生成一幅变换后的图像来训练神经网络。每个子策略由2个操作组成,每个操作与两个数值关联:调用操作的概率和操作的幅值大小。操作的概率表示有可能调用某个操作,也有可能该操作不会应用于该小批量。但是,如果应用,它将以固定的幅值应用。我们强调了应用子策略的随机性,通过展示一幅图像如何在不同的小批量中进行不同的转换,即使使用相同的子策略也有可能采用不同的操作。如文中所述,在SVHN上,几何变换更多地是通过自动增强来选择的。可以看出为什么反转是SVHN上常见的选择操作,因为图像中的数字对该变换是不变的。


实验中使用的运算来自PIL。我们还使用了两个很有应用前途的数据增强方法:Cutout和SamplePairing。我们使用的图像操作运算有:ShearX/Y,TranslateX/Y,Rotate, AutoContrast, Invert, Equalize, Solarize, Posterize, Contrast, Color, Brightness, Sharpness, Cutout, Sample Pairing。在我们的搜索空间中总共有16个图像操作,每个操作都具有默认的量级范围,将量级范围离散为10个值(均匀间距),这方便我们用离散搜索算法找到它们。类似的,我们也将操作的概率应用于11个值(均匀间距)来离散化。这使得搜索空间有(16*10*11)^2个可能性,我们的目标是同时找到这样的5个子策略以增加多样性。因此,包含5个子策略的搜索空间大约有(16*10*11)^10=2.9*10^32种可能性。

Search algorithm details


我们使用的搜索算法是强化学习。算法有两部分:RNN控制器,它是一个循环神经网络;训练算法是PPO。每一步中,控制器通过softmax进行决策,决策再输入下一步。控制器总共有30个决策,可以预测5个子策略,每个子策略包括两个操作,每个操作需要操作类型概率和使用幅值大小。

The training of controller RNN


控制器是用奖励信号训练的(学过强化学习算法的应该都知道),表示这个策略对模型泛化有多大提升。在实验中,我们留出了验证集来验证衡量子模型泛化性。子模型通过在训练集上(不包含验证集)应用5个子策略生成的增强的数据进行训练。对于小批量中的每个样本,从5个子策略随机选取一个来增强图像,然后在验证集上评估子模型来衡量精度,该精度用作RNN控制器的reward信号。在每个数据集上,控制器对大约15000个策略进行采样。

Architecture of controller RNN and training hyperparameters


控制器RNN是一层LSTM网络,每一层有100个隐藏单元,与每个架构决策相联系的两个卷积单元有2*5B个softmax预测(其中B通常为5)。控制器RNN的10B预测中的每一个都与概率相关。子网络的联合概率是这些10B最大值的所有概率的乘积。该联合概率用于计算控制器RNN的梯度。梯度通过子网络的验证精度进行缩放,以更新控制器RNN,从而控制器为性能不好的子网络分配低概率,为精度高的子网络分配高概率。我们强化学习算法采用了近端策略优化(PPO),学习率为0.00035。为了鼓励强化学习算法探索,我们还使用了权重为0.00001的熵惩罚。在我们的实现中,baseline function是以前奖励的指数移动平均值,权重为0.95。控制器的权重在-0.1和0.1之间均匀初始化。为了方便起见,我们选择使用PPO训练控制器,尽管之前的工作已经表明,其他方法(例如增强随机搜索和进化策略)也可以表现得很好。

最后,我们把最好的5个策略拼接为一个策略(包含25个子策略)。最后的一个策略包含25个子策略,用于训练每个数据集的模型。

Experiments and Results


子网络架构为小型Wide-ResNet-40-2(40层-widening为2)的模型,并训练120轮。选用这个模型是为了提高计算效率,因为每个子模型都是从头开始训练的。

下图中,我们展示了不同子模型神经网络架构下的测试集精度,并找到了权重衰减和学习率超参数,这些超参数为基线增强的常规训练提供了最佳验证集精度。

a0327277687326da5c3dc0dc1ddf02c0.png

图3   CIFAR-10、CIFAR-100和SVHN数据集上的测试集错误率(%)。越低越好。

fee79f41f7e60f34458447192fe88818.png

如上图所示,ImageNet上成功的策略之一。如本文所述,ImageNet上的大多数策略都使用基于颜色的转换。

57a8c90dc0efbd86e566d88e519ba43d.png

如上图所示,测试集上的Top-1/Top-5精度,越高的值表示性能越好。

Discussion


AutoAugment vs. other automated data augmentation methods

一种方法是GAN,生成器学习提出增强策略使得增强的图像可以骗过判别器。区别是,我们的方法试图直接优化分类精度,而他们的方法只是试图确保增强图像与当前训练图像相似。

fd2f119912357dc8d497ecf5a79d9061.png

如上图所示,对于这两种模型方法,自动增强会带来更高的改进(∼3.0%)。本文最后,作者认为本文的主要贡献在于我们的数据扩充方法和搜索空间的构建;不是在离散优化方法中,可以自己选择强化学习算法。

一些图像操作结果


3a2ee18cd53318a1491deeb0b1cab0da.png

2c56c4f8fa5bbc02ce0236c90b670a0d.png

相关文章
|
3天前
|
机器学习/深度学习 人工智能 算法
【图像版权】论文阅读:CRMW 图像隐写术+压缩算法
【图像版权】论文阅读:CRMW 图像隐写术+压缩算法
7 0
|
7天前
|
机器学习/深度学习 自然语言处理 算法
Python遗传算法GA对长短期记忆LSTM深度学习模型超参数调优分析司机数据|附数据代码
Python遗传算法GA对长短期记忆LSTM深度学习模型超参数调优分析司机数据|附数据代码
|
7天前
|
数据采集 机器学习/深度学习 算法
数据分享|WEKA关联规则挖掘Apriori算法在学生就业数据中的应用
数据分享|WEKA关联规则挖掘Apriori算法在学生就业数据中的应用
|
7天前
|
机器学习/深度学习 自然语言处理 算法
【大模型】关于减轻 LLM 训练数据和算法中偏差的研究
【5月更文挑战第6天】【大模型】关于减轻 LLM 训练数据和算法中偏差的研究
|
9天前
|
人工智能 算法 测试技术
论文介绍:进化算法优化模型融合策略
【5月更文挑战第3天】《进化算法优化模型融合策略》论文提出使用进化算法自动化创建和优化大型语言模型,通过模型融合提升性能并减少资源消耗。实验显示,这种方法在多种基准测试中取得先进性能,尤其在无特定任务训练情况下仍能超越参数更多模型。同时,该技术成功应用于创建具有文化意识的日语视觉-语言模型。然而,模型融合可能产生逻辑不连贯响应和准确性问题,未来工作将聚焦于图像扩散模型、自动源模型选择及生成自我改进的模型群体。[论文链接: https://arxiv.org/pdf/2403.13187.pdf]
112 1
|
14天前
|
机器学习/深度学习 数据采集 SQL
R语言K-Means(K均值聚类)和层次聚类算法对微博用户特征数据研究
R语言K-Means(K均值聚类)和层次聚类算法对微博用户特征数据研究
|
14天前
|
算法 数据可视化 数据挖掘
数据分享|R语言改进的K-MEANS(K-均值)聚类算法分析股票盈利能力和可视化
数据分享|R语言改进的K-MEANS(K-均值)聚类算法分析股票盈利能力和可视化
|
14天前
|
数据采集 存储 算法
数据分享|Weka数据挖掘Apriori关联规则算法分析用户网购数据
数据分享|Weka数据挖掘Apriori关联规则算法分析用户网购数据
|
15天前
|
数据采集 算法 安全
数据分享|R语言关联规则挖掘apriori算法挖掘评估汽车性能数据
数据分享|R语言关联规则挖掘apriori算法挖掘评估汽车性能数据
|
15天前
|
机器学习/深度学习 人工智能 运维
人工智能平台PAI 操作报错合集之请问Alink的算法中的序列异常检测组件,是对数据进行分组后分别在每个组中执行异常检测,而不是将数据看作时序数据进行异常检测吧
阿里云人工智能平台PAI (Platform for Artificial Intelligence) 是阿里云推出的一套全面、易用的机器学习和深度学习平台,旨在帮助企业、开发者和数据科学家快速构建、训练、部署和管理人工智能模型。在使用阿里云人工智能平台PAI进行操作时,可能会遇到各种类型的错误。以下列举了一些常见的报错情况及其可能的原因和解决方法。