分布外泛化,「经验风险最小化ERM」真的是最好的算法么?

简介: 上海交通大学联合华为诺亚方舟实验室 AI 基础理论团队以及香港科技大学,提出了一种新的面向非独立同分布域泛化问题的评价指标 OoD-Bench,同时对 OoD 领域构建了一个统一的框架。

上海交通大学联合华为诺亚方舟实验室 AI 基础理论团队和香港科技大学近期发现:多维度 OoD 现象在多个数据集广泛存在。和之前论文 Domainbed 的研究结论 OoD 算法无法打败 ERM 不同,现有的 OoD 算法大部分只能在一个维度的 OoD 问题上打败 ERM 算法,在另一个维度的 OoD 问题上则无法打败 ERM 算法。


团队提出一种新的面向非独立同分布域泛化问题的评价指标 OoD-Bench 《OoD-Bench: Benchmarking and Understanding Out-of-Distribution Generalization Datasets and Algorithms》, 已投稿 。该工作提出了一种更好更全面的评价 Out-of-Distribution (OoD)算法的指标,同时对 OoD 领域构建了一个统一的框架。


微信图片_20211205135028.jpg

           论文链接:https://arxiv.org/pdf/2106.03721v1.pdf


研究背景


传统的机器学习算法,通常假设训练样本和测试样本来自同一概率分布 Independent and Identically Distributed (i.i.d.)。但是对于 Out-of-Distribution (OoD)场景,即训练样本的概率分布和测试样本的概率分布不同的情况,训练出的模型很难在目标域取得良好的表现。现有机器学习系统的可靠性已经在多个重要应用领域收到广泛关注比如医学图像处理,自动驾驶场景及安全系统。


尽管近年来已经有许多 OoD 算法被提出,如何理解训练数据以及更好的衡量 OoD 算法仍然是一项具有挑战性的任务。本文识别和度量出两种在现实生活中 OoD 数据集广泛存在的 correlation shift 和 diversity shift 数据偏移问题,通过大量实验分析现有 OoD 算法在这两类基准数据集上的表现。同时,本文将多种之前联系较少的不同领域算法与数据集统一在 OoD 研究的框架之下,为之后对人工智能内在机制的研究提供统一的基准和衡量指标。


微信图片_20211205135155.jpg

图 1:不同的数据集存在多种维度的数据偏移:左右两边分别是典型的包含 Diversity shift 和 Correlation shift 的数据集,此外,现实中有很多 OoD 数据集是同时包含了这两种数据偏移。


深度学习中的分布外 (OoD) 泛化是指模型在分布变化的场景下进行泛化的任务。我们假设在训练的过程中模型可以接触到相同任务但来自于不同环境与实验条件的数据集。OoD 泛化算法的目标是提取这些训练的不变性表征,假设这种不变性表征也能在未知的测试环境中保持。近年来,许多相关 OoD 算法被提出并声称在特定类型的基准测试集上超越了所有先前的工作。然而,最近的一项工作表明,目前大多数为 OoD 泛化设计的学习算法,仍然与经典的经验风险最小化方法 ERM 相当。


本文通过大量实验验证现有 OoD 算法的有效性,并揭示了一个看起来并不比 ERM 好很多的可能原因。事实表明,表现出分布变化的现有数据集通常可以分为具有不同特点的两类,如图 1 所示,大多数算法只能最多在其中一个类别数据集中超过 ERM。研究假设这种现象是由于两种不同的分布偏移的影响,即多样性迁移

(Diversity shift) 和相关性迁移(Correlation shift),而先前的工作往往只关注其中之一。


基于大量的实验和分析,本文为之后的 OoD 泛化研究提出了三点建议:

  1. OoD 算法应在两种类型的数据集上进行全面的评估,一种以多样性偏移 (Diversity shift) 为主,另一种以相关性偏移 (Correlation shift) 为主。这两种分布可以通过该研究的量化方法测量偏移;
  2. 在设计 OoD 算法之前可以先探究所要解决的 OoD 问题中分布偏移 (Distribution shift) 的性质,对于不同类型分布偏移的最佳处理方式可能不同;
  3. 设计能够更巧妙地捕捉现实世界分布变化的大规模数据集。该研究的实验与分析显示,人眼难以察觉的分布变化对于神经网络的可靠性也有很明显的影响。


方法概述


在监督式学习的设定下,不妨假设:输入变量 X 是由一系列潜变量决定的,可以把这些潜变量一分为二,记作 Z1 和 Z2,其中只有 Z1 才能决定目标变量 Y 。

给定训练和测试环境及其相关的概率密度函数 p 和 q,在假设不存在 label shift 的前提下,符合下列条件的 Z1 的存在使分布外泛化成为可能:


微信图片_20211205135304.jpg


另一方面,符合相反条件的 Z2 的存在使分布外泛化变得困难:


微信图片_20211205135319.jpg


Diversity shift 就是由满足 Z2 第一个条件的特征所引发的,而 correlation shift 则是由满足第二个条件的特征所引发的。


Diversity shift 的标志是仅出现在训练环境,没有出现在测试环境中的特征(或者相反)。例如在 PACS 里,照片中的色彩在速写中完全消失。可以把这些特征记作:

微信图片_20211205135339.jpg

于是 diversity shift 就被定义为:

微信图片_20211205135358.jpg

当 n=1 时,它们的含义可以被描绘如下:

微信图片_20211205135426.jpg

图 2:Diversity shift 和 correlation shift 的描绘。


Diversity shift 等于左图彩色区域面积总和的一半。Correlation shift 是在点集上的积分,每个被积分式的值可以被看作右图彩条高度之和的一半,乘上作为权重的两概率值乘积的平方根。

实际计算时,通过训练一个神经网络来提取计算所需的特征,以便进行估算。在各种数据集上的估算结果如下:微信图片_20211205135506.jpg

图 3:对于多种不同数据集度量 diversity shift 和 correlation shift。


这一结果与直觉相符:现有的大多数分布外泛化基准数据集都落在坐标轴之上或附近,意味着它们都只被两者之一所主导。对于存在不明分布偏移(distribution shift)的数据集,例如 ImageNet-A,ImageNet-R 和 ImageNet-V2,该研究的方法成功地将其所具有的偏移分解到 diversity 和 correlation 两个维度上,因此可以通过该研究的估计结果来针对不同的数据集选择合适的算法。


如接下来的 benchmark 结果所示,这类算法选择可能是关键的,因为大多数分布外泛化算法不能同时在两类数据集上都表现好,一类是被 diversity shift 所主导,另一类是被 correlation shift 所主导。


实验


该研究对 16 种不同算法 (ERM、GroupDRO、Mixup、 MLDG、DANN、CORAL、MMD、IRM、VREx、ARM、MTL、 SagNet、RSC、ANDMask、IGA、ERDG) 在 7 种不同 OoD 数据集 (PACS、 OfficeHome、Terra Incognita、WILDS-Camelyon17、Colored-MNIST、 NICO、CelebA) 上的表现进行了测试和分析。


实验结果


微信图片_20211205135542.jpg

表 1:ERM 和 OoD 算法在偏向 Diversity shift 数据集上的结果。


微信图片_20211205135607.jpg

表 2:ERM 和 OoD 算法在偏向 Correlation shift 数据集上的结果。


基准测试结果如表 1 和表 2 所示,除了平均准确度和标准误差,该研究还计算了每个算法相对于 ERM 的排名分数。具体来说,对于每个数据集 - 算法对,每个算法与 ERM 相比分别赋予分数: -1(低于),0(相当),1(高于)。最后将表中所列出的数据集分数相加得出排名分数。该排名分数反映了 Diversity shift 与 Correlation shift 的相对程度鲁棒性。从中可以看出,多数现有的 OoD 算法与 ERM 相比,并不能取得持续的性能提升。比如在 diversity shift 主导的数据集上 MMD、RSC、IGA 和 SagNet 的结果比 ERM 要高,但是在 correlation shift 主导的数据集上与 ERM 相比会低。
因此,该研究提出了衡量一个 OoD 算法的有效性,应该同时测试 diversity shift 和 correlation shift 两个维度的 OoD 性能。


可视化分析


微信图片_20211205135633.jpg

图 4: Attention 可视化效果图。


图 4 展示了不同算法所学到表征的可视化效果。由于篇幅所限,这里选择展示了两个具有代表性的算法: RSC 和 VREx,用于与 ERM 作比较。左边两列是来自 PACS 的图片,RSC 显示出比 ERM 和 VREx 更好的效果,因为 RSC 具有更广的关注范围,因此能捕捉到更多的全局结构信息而不是局部细节。右边的两列是来自 NICO 的样本,从图中可以看出,RSC 的注意力被非因果和局部特征(如背景和身体部分)吸引。相比之下,ERM 覆盖了更多的区域,包括感兴趣的目标位置,而 VREx 的注意力更加多样化,覆盖分散在整个图像中的不同区域。此外,注意力强度较弱,表明 VREx 不容易对虚假相关性过度自信。


微信图片_20211205135708.jpg

图 5: 对于 Colored MNIST 数据集,在不同的色彩分布下,估测 diversity shift 和 correlation shift。


图 5 是对多样性偏移与相关性偏移的估计。为了验证本文量化估计方法的鲁棒性,该研究对 Colored-MNIST 数据集进行消融实验研究,以检查是否可以产生稳定的结果,反映改变颜色变化时的预期趋势。


对比实验


微信图片_20211205135739.jpg

表 3:在只有一个训练环境的情况下对于 Colored MNIST 数据集测量 diversity shift。


该研究还将 OoD-Bench 与其他测量方法进行比较,表 3 显示了在 Colored-MNIST 数据集上的结果。结果发现,一般用于衡量分布之间差异的指标,比如 EMD 和 MMD,对 OoD 数据集中的相关性偏移不敏感,而 EMD 数据集同时对多样性变化不敏感。虽然 NI 可以在相关性偏移上产生比较结果,但它仍像 EMD 和 MMD 一样是一维的,无法区分数据集中存在的各种分布变化。该研究的方法提供了更稳定和可解释的比较结果。


总结


本文识别和量化 OoD 数据集两种主要的分布偏差: diversity shift 和 correlation shift,并阐明了一些真实世界的数据,未知分布变化的本质。此外,该研究还通过大量实验,展示了现有 OoD 算法的优势与劣势。结果表明,未来的算法必须同时在两种类型数据集进行综合评估,以便完整的评估 OoD 算法的性能。

相关文章
|
4月前
|
数据采集 机器学习/深度学习 算法
【python】python客户信息审计风险决策树算法分类预测(源码+数据集+论文)【独一无二】
【python】python客户信息审计风险决策树算法分类预测(源码+数据集+论文)【独一无二】
|
6月前
|
机器学习/深度学习 算法 测试技术
如何应对缺失值带来的分布变化?探索填充缺失值的最佳插补算法
该文探讨了缺失值插补的不同方法,比较了它们恢复数据真实分布的效果。文章指出,处理插补尤其在小样本或复杂数据时是个挑战,需要选择能适应数据分布变化的方法。文中介绍了完全随机缺失(MCAR)、随机缺失(MAR)和非随机缺失(MNAR)三种机制,并以一个简单的例子展示了数据分布变化。文章通过比较均值插补、回归插补和高斯插补,强调了高斯插补在重现数据分布方面更优。评估插补方法时,不应仅依赖于RMSE,而应关注分布预测,使用如能量距离这样的指标。此外,即使在随机缺失情况下,数据分布也可能因模式变化而变化,需要考虑适应这些变化的插补方法。
195 2
|
机器学习/深度学习 算法 机器人
路径规划算法:基于蜻蜓分布优化的机器人路径规划算法- 附matlab代码
路径规划算法:基于蜻蜓分布优化的机器人路径规划算法- 附matlab代码
|
机器学习/深度学习 人工智能 算法
实用!50个大厂、987页大数据、算法项目落地经验教程合集
大数据、算法项目在任何大厂无论是面试还是工作运用都是非常广泛的,我们精选了50个百度、腾讯、阿里等大厂的大数据、算法落地经验甩给大家,千万不要做收藏党哦,空闲时间记得随时看看! 如果你没有大厂项目经验,对大厂算法、大数据的项目运用不了解建议你看看!
|
机器学习/深度学习 传感器 算法
路径规划算法:基于指数分布优化的机器人路径规划算法- 附matlab代码
路径规划算法:基于指数分布优化的机器人路径规划算法- 附matlab代码
|
编解码 算法 数据可视化
【高光谱图像的去噪算法】通过全变异最小化对受激拉曼光谱图像进行去噪研究(Matlab代码实现)
【高光谱图像的去噪算法】通过全变异最小化对受激拉曼光谱图像进行去噪研究(Matlab代码实现)
112 0
|
数据采集 监控 算法
【分布鲁棒、状态估计】分布式鲁棒优化电力系统状态估计研究[几种算法进行比较](Matlab代码实现)
【分布鲁棒、状态估计】分布式鲁棒优化电力系统状态估计研究[几种算法进行比较](Matlab代码实现)
104 0
|
算法 安全 调度
基于串行和并行ADMM算法在分布式调度中的应用(Matlab代码实现)
基于串行和并行ADMM算法在分布式调度中的应用(Matlab代码实现)
125 0
|
人工智能 算法
机器博弈 (三) 虚拟遗憾最小化算法
机器博弈 (三) 虚拟遗憾最小化算法
241 0
|
机器学习/深度学习 人工智能 开发框架
机器博弈 (二) 遗憾最小化算法
机器博弈 (二) 遗憾最小化算法
194 0
下一篇
DataWorks