【博士每天一篇文献-算法】连续学习算法之RWalk:Riemannian Walk for Incremental Learning Understanding

简介: RWalk算法是一种增量学习框架,通过结合EWC++和修改版的Path Integral算法,并采用不同的采样策略存储先前任务的代表性子集,以量化和平衡遗忘和固执,实现在学习新任务的同时保留旧任务的知识。

阅读时间:2023-12-27

1 介绍

年份:2018
作者:Arslan Chaudhry,DeepMind;Puneet K. Dokania,牛津大学
会议: ECCV
引用量:1138
Chaudhry A, Dokania P K, Ajanthan T, et al. Riemannian walk for incremental learning: Understanding forgetting and intransigence[C]//Proceedings of the European conference on computer vision (ECCV). 2018: 532-547.
RWalk算法是一种增量学习框架,它通过结合EWC++(Elastic Weight Consolidation的高效版本)和修改版的Path Integral(PI)算法,改进的PI算法是使用KL散度。本文讨论了不同的采样策略,以存储先前任务数据集的一小部分代表性子集,这有助于网络回忆先前任务的信息,并学习区分当前和先前的任务。RWalk在MNIST和CIFAR-100数据集上的实验结果表明,在准确性方面取得了优越的结果,并且在遗忘和固执之间提供了更好的权衡。
image.png
image.png

2 创新点

  1. 提出新的连续学习评估指标
    • 提出新的评估指标,包括遗忘(Forgetting)和固执(Intransigence),这些指标专门用于量化和评估增量学习算法在保留旧知识的同时学习新知识的能力。
  2. 遗忘和固执的量化
    • 定义了遗忘度量(Forgetting Measure),通过比较模型在历史任务中的最大知识获取与当前状态的知识获取来量化遗忘。定义了固执度量(Intransigence Measure),通过比较增量学习模型与标准分类模型在新任务上的表现来量化模型学习新任务的能力。
  3. RWalk算法
    • 提出了RWalk算法,这是一种结合了EWC++和Path Integral(PI)的增量学习算法,具有理论上基于KL散度的视角。RWalk算法通过正则化条件似然分布来保留先前任务的知识,并使用优化路径上的重要性分数来更新知识,从而在遗忘和固执之间取得平衡。
  4. EWC++算法
    • 提出了EWC++算法,这是Elastic Weight Consolidation(EWC)的高效在线版本,通过使用移动平均来更新Fisher信息矩阵,减少了对内存的需求和计算复杂度。
  5. 优化路径上的重要性分数
    • 引入了基于优化路径的参数重要性分数,通过在整个训练过程中累积参数对损失函数变化的影响来量化参数的重要性。
  6. 代表性样本的采样策略
    • 研究了不同的采样策略,如均匀采样、基于平面距离的采样、基于熵的采样和基于特征均值(MoF)的采样,以存储先前任务的代表性样本,从而帮助模型在单头评估设置中更好地区分当前和先前的任务。
  7. 实验验证
    • 在MNIST和CIFAR-100数据集上进行了广泛的实验,验证了RWalk算法在准确性、遗忘和固执方面的优越性能,展示了其在实际增量学习任务中的应用潜力。
  8. 理论基础和分析
    • 提供了关于RWalk算法的理论基础和分析,包括Fisher信息矩阵在正则化中的作用,以及KL散度在Riemannian流形上的距离度量。

3 相关研究

  1. 动态网络扩展:一些方法通过为每个新任务动态扩展网络结构来解决遗忘问题,但这些方法随着任务数量的增加而变得不可扩展。
    • Rebuffi, S.A., Bilen, H., Vedaldi, A.: Learning multiple visual domains with residual adapters. In: NIPS (2017)
    • Rusu, A.A., Rabinowitz, N.C., Desjardins, G., Soyer, H., Kirkpatrick, J., Kavukcuoglu, K., Pascanu, R., Hadsell, R.: Progressive neural networks. arXiv preprint arXiv:1606.04671 (2016)
    • Terekhov, A.V., Montone, G., ORegan, J.K.: Knowledge transfer in deep block-modular neural networks. In: Conference on Biomimetic and Biohybrid Systems. pp. 268–279 (2015)
    • Yoon, J., Yang, E., Lee, J., Hwang, S.J.: Lifelong learning with dynamically expandable networks. In: ICLR (2018)
  2. EWC(Elastic Weight Consolidation):一种基于参数重要性的方法,通过正则化来减少遗忘。
    • Overcoming catastrophic forgetting in neural networks. Proceedings of the National Academy of Sciences of the United States of America (PNAS) (2016)
  3. PI(Path Integral):另一种基于参数重要性的正则化方法,与EWC类似,但使用不同的参数重要性度量。
    • Zenke, F., Poole, B., Ganguli, S.: Continual learning through synaptic intelligence. In: ICML (2017)
  4. 在线EWC:与EWC++同时提出的在线版本的EWC,用于更有效的参数更新。
    • Schwarz, J., Luketina, J., Czarnecki, W.M., Grabska-Barwinska, A., Teh, Y.W., Pascanu, R., Hadsell, R.: Progress & compress: A scalable framework for continual learning. In: ICML (2018)
  5. Moment Matching:Lee等人使用矩匹配方法来组合所有任务的网络权重。
    • Lee, S.W., Kim, J.H., Ha, J.W., Zhang, B.T.: Overcoming catastrophic forgetting by incremental moment matching In: NIPS (2017)
  6. 贝叶斯框架:Nguyen等人通过贝叶斯框架强制执行模型参数分布的接近性。
    • Nguyen, C.V., Li, Y., Bui, T.D., Turner, R.E.: Variational continual learning. ICLR
  7. Gradient Episodic Memory:Lopez-Paz和Ranzato更新梯度,以确保先前任务的损失不会增加。
    • Lopez-Paz, D., Ranzato, M.: Gradient episodic memory for continuum learning. In: NIPS (2017)
  8. Continual Learning with Deep Generative Replay:Shin等人使用学习到的生成模型来重新生成先前任务的样本进行再训练。
    • Lopez-Paz, D., Ranzato, M.: Gradient episodic memory for continuum learning. In: NIPS (2017)
  9. iCaRL(Incremental Classifier and Representation Learning):Rebuffi等人提出的方法,使用基于激活的正则化和最近邻分类器,以及存储先前任务的样本。
    • Rebuffi, S.V., Kolesnikov, A., Lampert, C.H.: iCaRL: Incremental classifier and representation learning. In: CVPR (2017)

4 算法

在这里插入图片描述
图中是参数重要性随优化轨迹累积的情况,参数重要性是一个反映模型参数对损失函数变化敏感度的指标。在RWalk中,这个指标是基于优化路径上参数变化对损失函数的影响来计算的。优化轨迹指的是在训练过程中,模型参数随着优化算法(如梯度下降)更新的路径。每次参数更新导致损失函数的变化都会被记录下来,并用于计算该参数的重要性分数。小的参数更新如果导致损失函数显著下降,则该参数被认为更重要。

  1. 定义问题设置
    • 将任务流定义为一系列标签集,每个任务对应一个数据集,数据集中包含输入和对应的真实标签。
  2. 评估设置
    • 区分单头(single-head)和多头(multi-head)评估。单头评估在测试时不知道任务标识符,而多头评估在测试时知道任务标识符。
  3. 知识定义
    • 定义知识为模型学习到的内容,可以基于网络的输入输出行为或网络参数。
  4. 引入评估指标
    • 提出两个新的评估指标:遗忘度量(Forgetting Measure)和固执度量(Intransigence Measure),用于评估模型在历史任务和当前任务上的性能。
  5. RWalk的组合正则化项
    • 提出EWC++算法,使用移动平均来更新Fisher信息矩阵,减少了内存需求和计算复杂度。
    • 提出修改的PI算法,使用近似的KL散度来衡量输出分布之间的距离,而不是像原始PI那样使用欧几里得空间中的损失变化。
    • 将基于EWC++算法的Fisher信息的参数重要性和PI算法的基于优化路径的重要性分数结合起来,即将两个正则化项结合起来,形成RWalk的目标函数。

$ \tilde{L}_k(\theta) = L_k(\theta) + \lambda \sum_{i=1}^{P} (F_{\theta_{k-1}}^i + s_{k-1}^{t_0}(\theta_i))(\theta_i - \theta_{k-1}^i)^2 $
其中, $ L_k(\theta) $是第k个任务的损失函数, $ F_{\theta{k-1}}^i $​是第i个参数的Fisher信息, $ s_{k-1}^{t_0}(\theta_i) $是从训练开始到第k-1个任务结束时累积的参数重要性分数, λ是正则化系数。

  1. 处理固执问题
  • 通过存储先前任务的一小部分代表性样本,帮助模型在单头评估中区分当前和先前的任务,从而减少固执。
  1. 采样策略
  • 研究不同的采样策略,包括均匀采样、基于平面距离的采样、基于熵的采样和基于特征均值(MoF)的采样,以选择代表性样本。
  1. 实验验证
  • 在MNIST和CIFAR-100数据集上进行实验,验证RWalk算法在准确性、遗忘和固执方面的表现。

5 实验分析

(1)在MNIST和CIFAR数据集上的比较结果

  • Vanilla:没有使用任何正则化,直接训练网络的方法。
  • EWC:使用Elastic Weight Consolidation方法的基线,它是一种参数重要性正则化方法。
  • PI:使用Path Integral方法的基线,也是一种参数重要性正则化方法。
  • iCaRL-hb1:使用incremental classifier and representation learning方法的混合版本1,它使用基于激活的正则化和最近邻分类器。
  • RWalk (Ours):本文提出的算法,结合了EWC++和修改版的PI。

image.png
在多头评估设置中,除了Vanilla方法外,所有方法都展现出了几乎零遗忘和固执,平均准确率非常高,这表明在这种设置下增量学习问题似乎已经得到了解决。
引入先前任务的代表性样本('-S’标记的方法)显著减少了固执问题,提高了模型在单头评估设置中的性能。
不同的方法对正则化系数λ的敏感度不同。例如,EWC和PI对λ非常敏感,而RWalk由于其正则化项的累积和平均机制,对λ的敏感度较低。
(2)用多头评估和单头评估时模型性能的变化
image.png

  • 在多任务评估中,所有方法的准确率都很高,表明在知道当前任务标识符的情况下,模型能够很好地学习并识别当前任务的样本。
  • 在单任务评估中,如果不使用先前任务的样本,模型的准确率会显著下降。这是因为模型需要在没有任何提示的情况下识别样本可能属于的任何先前学习过的任务,这对模型是一个更大的挑战。
  • 当在单任务评估中引入先前任务的代表性样本时,模型的准确率有所提高。这表明代表性样本有助于减少固执问题,帮助模型更好地学习和泛化先前任务的知识。
  • 随着任务的增加,如果不用样本,性能可能会下降,特别是对于Vanilla(未正则化的)模型。
  • 使用MoF采样策略的代表性样本可以显著改善模型在单任务评估设置下的性能,使模型更接近多任务评估的性能水平。

(3)增量学习中遗忘(Forgetting)和固执(Intransigence)之间的关系
image.png
图表被划分为四个象限,每个象限代表不同的遗忘和固执组合:

  • 正面遗忘和正面固执(PBT, PFT):模型在新任务上表现良好,但对旧任务的记忆有所下降。
  • 正面遗忘和负面固执(PBT, NFT):模型在新任务上表现良好,但对旧任务的记忆保持得更好。
  • 负面遗忘和正面固执(NBT, PFT):模型在旧任务上保持了较好的记忆,但在新任务上表现不佳。
  • 负面遗忘和负面固执(NBT, NFT):模型在旧任务和新任务上都保持了较好的性能。

理想情况下,模型位于图表的左下角,即低遗忘和低固执区域,这表示模型在学习新任务时能够很好地保留旧知识,并且对新任务有很好的适应性。

(4)不同模型在增量学习任务上的性能变化
image.png
在只有少量样本的情况下,RWalk算法也能够提供稳定的性能,这表明其正则化策略有效地减少了对大量样本的依赖。
Vanilla模型虽然最终能够通过增加样本数量提高性能,但这在实际应用中可能不可行,因为存储和处理大量样本资源成本高昂。
RWalk算法通过正则化减少了对历史样本的依赖,使得模型在连续学习任务中更加高效和有效。
(5)不同采样策略时的性能差异
image.png
均匀采样(Uniform Sampling)、基于平面距离的采样(Plane Distance-based Sampling, PD)、基于熵的采样(Entropy-based Sampling)和基于特征均值的采样(Mean of Features, MoF)。

  • MoF采样策略的优势:MoF采样策略在两个数据集上都优于其他采样策略,这表明基于特征空间中样本与类别均值的接近程度来选择样本是一种有效的策略。
  • 简单采样策略的表现:尽管均匀采样是一种非常简单的采样方法,但在某些情况下,它的表现与更复杂的MoF采样策略相当接近,特别是在数据集相对简单的情况下(如MNIST)。
  • 复杂数据集的挑战:在更复杂的数据集(如CIFAR-100)上,采样策略的选择对性能的影响更为显著,MoF采样策略的优势更加明显。
  • 不同模型的稳健性:某些正则化方法可能对采样策略的选择不太敏感,表明这些方法在不同采样策略下都能保持较好的性能。

6 思考

(1)什么是固执问题?
固执问题(Intransigence)指的是模型在学习新任务时难以更新其知识,从而导致对新信息的适应性差。具体来说,当一个模型已经学习了一系列任务后,它可能会变得难以对新的数据或任务进行有效的学习或调整,这种现象就是固执。
固执问题通常与以下现象相关:

  1. 新任务学习困难:模型对新任务的数据不够敏感,难以捕捉新任务的特征和模式。
  2. 性能停滞不前:即使在新任务上继续训练,模型的性能提升有限或者根本没有提升。
  3. 泛化能力下降:模型可能在新任务上表现不佳,无法很好地泛化到未见过的数据。
  4. 先前知识固化:模型可能过于依赖于先前学习的知识,而难以对新知识进行整合。

(2)什么是单头评估和多头评估?
单头评估(Single-head Evaluation)和多头评估(Multi-head Evaluation)是增量学习(Incremental Learning)中用来评估模型性能的两种不同的测试设置:

  1. 多头评估(Multi-head Evaluation):
    • 在多头评估中,模型在测试时知道当前任务的标识符。这意味着在进行预测时,模型只需要关注当前任务的相关类别,而不需要考虑之前学习过的所有类别。
    • 例如,如果模型已经学习了10个类别,并且当前任务是第5个类别,那么在多头评估中,模型在测试时只需要从第5个类别的样本中进行预测。
    • 这种设置相对容易,因为模型在测试时可以专注于当前任务的类别,不需要进行跨任务的区分。
  2. 单头评估(Single-head Evaluation):
    • 单头评估是一种更具挑战性的设置,因为在测试时模型不知道当前任务的标识符。这意味着模型需要从所有先前学习过的类别中进行预测。
    • 继续上述例子,如果模型在第5个任务上进行单头评估,它需要能够预测0到4和当前任务5的类别,即所有它之前见过的类别。
    • 这种设置模拟了现实世界中的场景,模型需要在没有任何任务提示的情况下识别新样本的类别,这对模型的泛化能力和记忆能力提出了更高的要求。

两种评估设置的主要区别在于模型在测试时是否知道当前任务的上下文:

  • 多头评估更简单,因为它只要求模型识别当前任务的类别。
  • 单头评估更复杂,因为它要求模型能够识别所有先前学习过的任务的类别。

在实际应用中,单头评估通常更接近现实情况,因为它要求模型具有更好的遗忘抑制能力和对新任务的适应性。这也是为什么在评估增量学习算法时,单头评估被认为是一个更加严格和实际的测试方法。
(3)MoF采样策略是什么?
MoF(Mean of Features,特征均值)采样策略是一种用于增量学习中的技术,旨在从先前任务的数据集中选择具有代表性的样本。对于每个类别,MoF策略计算在特征空间中所有样本的均值。这个均值向量代表了该类别在特征空间中的中心位置。然后,MoF策略选择那些在特征空间中接近其类别均值的样本。这些样本被认为更能代表其类别的特征。
(4)论文中的PI算法就是SI算法,性能上没有提升,但是时间效率上提升了。并解决了遗忘(Forgetting)和固执(Intransigence)的固执问题。另一个角度去评价连续学习算法的好坏,非常大的启发。说明不用去卷遗忘准确率。

目录
相关文章
|
6天前
|
算法 JavaScript 前端开发
第一个算法项目 | JS实现并查集迷宫算法Demo学习
本文是关于使用JavaScript实现并查集迷宫算法的中国象棋demo的学习记录,包括项目运行方法、知识点梳理、代码赏析以及相关CSS样式表文件的介绍。
第一个算法项目 | JS实现并查集迷宫算法Demo学习
|
10天前
|
XML JavaScript 前端开发
学习react基础(1)_虚拟dom、diff算法、函数和class创建组件
本文介绍了React的核心概念,包括虚拟DOM、Diff算法以及如何通过函数和类创建React组件。
15 2
|
2月前
|
机器学习/深度学习 人工智能 资源调度
【博士每天一篇文献-算法】连续学习算法之HAT: Overcoming catastrophic forgetting with hard attention to the task
本文介绍了一种名为Hard Attention to the Task (HAT)的连续学习算法,通过学习几乎二值的注意力向量来克服灾难性遗忘问题,同时不影响当前任务的学习,并通过实验验证了其在减少遗忘方面的有效性。
49 12
|
2月前
|
机器学习/深度学习 算法 计算机视觉
【博士每天一篇文献-算法】持续学习经典算法之LwF: Learning without forgetting
LwF(Learning without Forgetting)是一种机器学习方法,通过知识蒸馏损失来在训练新任务时保留旧任务的知识,无需旧任务数据,有效解决了神经网络学习新任务时可能发生的灾难性遗忘问题。
99 9
|
2月前
|
算法 Java
掌握算法学习之字符串经典用法
文章总结了字符串在算法领域的经典用法,特别是通过双指针法来实现字符串的反转操作,并提供了LeetCode上相关题目的Java代码实现,强调了掌握这些技巧对于提升算法思维的重要性。
|
2月前
|
算法 NoSQL 中间件
go语言后端开发学习(六) ——基于雪花算法生成用户ID
本文介绍了分布式ID生成中的Snowflake(雪花)算法。为解决用户ID安全性与唯一性问题,Snowflake算法生成的ID具备全局唯一性、递增性、高可用性和高性能性等特点。64位ID由符号位(固定为0)、41位时间戳、10位标识位(含数据中心与机器ID)及12位序列号组成。面对ID重复风险,可通过预分配、动态或统一分配标识位解决。Go语言实现示例展示了如何使用第三方包`sonyflake`生成ID,确保不同节点产生的ID始终唯一。
go语言后端开发学习(六) ——基于雪花算法生成用户ID
|
2天前
|
传感器 算法 C语言
基于无线传感器网络的节点分簇算法matlab仿真
该程序对传感器网络进行分簇,考虑节点能量状态、拓扑位置及孤立节点等因素。相较于LEACH算法,本程序评估网络持续时间、节点死亡趋势及能量消耗。使用MATLAB 2022a版本运行,展示了节点能量管理优化及网络生命周期延长的效果。通过簇头管理和数据融合,实现了能量高效和网络可扩展性。
|
29天前
|
算法 BI Serverless
基于鱼群算法的散热片形状优化matlab仿真
本研究利用浴盆曲线模拟空隙外形,并通过鱼群算法(FSA)优化浴盆曲线参数,以获得最佳孔隙度值及对应的R值。FSA通过模拟鱼群的聚群、避障和觅食行为,实现高效全局搜索。具体步骤包括初始化鱼群、计算适应度值、更新位置及判断终止条件。最终确定散热片的最佳形状参数。仿真结果显示该方法能显著提高优化效率。相关代码使用MATLAB 2022a实现。
|
29天前
|
算法 数据可视化
基于SSA奇异谱分析算法的时间序列趋势线提取matlab仿真
奇异谱分析(SSA)是一种基于奇异值分解(SVD)和轨迹矩阵的非线性、非参数时间序列分析方法,适用于提取趋势、周期性和噪声成分。本项目使用MATLAB 2022a版本实现从强干扰序列中提取趋势线,并通过可视化展示了原时间序列与提取的趋势分量。代码实现了滑动窗口下的奇异值分解和分组重构,适用于非线性和非平稳时间序列分析。此方法在气候变化、金融市场和生物医学信号处理等领域有广泛应用。
|
1月前
|
资源调度 算法
基于迭代扩展卡尔曼滤波算法的倒立摆控制系统matlab仿真
本课题研究基于迭代扩展卡尔曼滤波算法的倒立摆控制系统,并对比UKF、EKF、迭代UKF和迭代EKF的控制效果。倒立摆作为典型的非线性系统,适用于评估不同滤波方法的性能。UKF采用无迹变换逼近非线性函数,避免了EKF中的截断误差;EKF则通过泰勒级数展开近似非线性函数;迭代EKF和迭代UKF通过多次迭代提高状态估计精度。系统使用MATLAB 2022a进行仿真和分析,结果显示UKF和迭代UKF在非线性强的系统中表现更佳,但计算复杂度较高;EKF和迭代EKF则更适合维数较高或计算受限的场景。
下一篇
无影云桌面