【博士每天一篇文献-算法】连续学习算法之RWalk:Riemannian Walk for Incremental Learning Understanding

简介: RWalk算法是一种增量学习框架,通过结合EWC++和修改版的Path Integral算法,并采用不同的采样策略存储先前任务的代表性子集,以量化和平衡遗忘和固执,实现在学习新任务的同时保留旧任务的知识。

阅读时间:2023-12-27

1 介绍

年份:2018
作者:Arslan Chaudhry,DeepMind;Puneet K. Dokania,牛津大学
会议: ECCV
引用量:1138
Chaudhry A, Dokania P K, Ajanthan T, et al. Riemannian walk for incremental learning: Understanding forgetting and intransigence[C]//Proceedings of the European conference on computer vision (ECCV). 2018: 532-547.
RWalk算法是一种增量学习框架,它通过结合EWC++(Elastic Weight Consolidation的高效版本)和修改版的Path Integral(PI)算法,改进的PI算法是使用KL散度。本文讨论了不同的采样策略,以存储先前任务数据集的一小部分代表性子集,这有助于网络回忆先前任务的信息,并学习区分当前和先前的任务。RWalk在MNIST和CIFAR-100数据集上的实验结果表明,在准确性方面取得了优越的结果,并且在遗忘和固执之间提供了更好的权衡。
image.png
image.png

2 创新点

  1. 提出新的连续学习评估指标
    • 提出新的评估指标,包括遗忘(Forgetting)和固执(Intransigence),这些指标专门用于量化和评估增量学习算法在保留旧知识的同时学习新知识的能力。
  2. 遗忘和固执的量化
    • 定义了遗忘度量(Forgetting Measure),通过比较模型在历史任务中的最大知识获取与当前状态的知识获取来量化遗忘。定义了固执度量(Intransigence Measure),通过比较增量学习模型与标准分类模型在新任务上的表现来量化模型学习新任务的能力。
  3. RWalk算法
    • 提出了RWalk算法,这是一种结合了EWC++和Path Integral(PI)的增量学习算法,具有理论上基于KL散度的视角。RWalk算法通过正则化条件似然分布来保留先前任务的知识,并使用优化路径上的重要性分数来更新知识,从而在遗忘和固执之间取得平衡。
  4. EWC++算法
    • 提出了EWC++算法,这是Elastic Weight Consolidation(EWC)的高效在线版本,通过使用移动平均来更新Fisher信息矩阵,减少了对内存的需求和计算复杂度。
  5. 优化路径上的重要性分数
    • 引入了基于优化路径的参数重要性分数,通过在整个训练过程中累积参数对损失函数变化的影响来量化参数的重要性。
  6. 代表性样本的采样策略
    • 研究了不同的采样策略,如均匀采样、基于平面距离的采样、基于熵的采样和基于特征均值(MoF)的采样,以存储先前任务的代表性样本,从而帮助模型在单头评估设置中更好地区分当前和先前的任务。
  7. 实验验证
    • 在MNIST和CIFAR-100数据集上进行了广泛的实验,验证了RWalk算法在准确性、遗忘和固执方面的优越性能,展示了其在实际增量学习任务中的应用潜力。
  8. 理论基础和分析
    • 提供了关于RWalk算法的理论基础和分析,包括Fisher信息矩阵在正则化中的作用,以及KL散度在Riemannian流形上的距离度量。

3 相关研究

  1. 动态网络扩展:一些方法通过为每个新任务动态扩展网络结构来解决遗忘问题,但这些方法随着任务数量的增加而变得不可扩展。
    • Rebuffi, S.A., Bilen, H., Vedaldi, A.: Learning multiple visual domains with residual adapters. In: NIPS (2017)
    • Rusu, A.A., Rabinowitz, N.C., Desjardins, G., Soyer, H., Kirkpatrick, J., Kavukcuoglu, K., Pascanu, R., Hadsell, R.: Progressive neural networks. arXiv preprint arXiv:1606.04671 (2016)
    • Terekhov, A.V., Montone, G., ORegan, J.K.: Knowledge transfer in deep block-modular neural networks. In: Conference on Biomimetic and Biohybrid Systems. pp. 268–279 (2015)
    • Yoon, J., Yang, E., Lee, J., Hwang, S.J.: Lifelong learning with dynamically expandable networks. In: ICLR (2018)
  2. EWC(Elastic Weight Consolidation):一种基于参数重要性的方法,通过正则化来减少遗忘。
    • Overcoming catastrophic forgetting in neural networks. Proceedings of the National Academy of Sciences of the United States of America (PNAS) (2016)
  3. PI(Path Integral):另一种基于参数重要性的正则化方法,与EWC类似,但使用不同的参数重要性度量。
    • Zenke, F., Poole, B., Ganguli, S.: Continual learning through synaptic intelligence. In: ICML (2017)
  4. 在线EWC:与EWC++同时提出的在线版本的EWC,用于更有效的参数更新。
    • Schwarz, J., Luketina, J., Czarnecki, W.M., Grabska-Barwinska, A., Teh, Y.W., Pascanu, R., Hadsell, R.: Progress & compress: A scalable framework for continual learning. In: ICML (2018)
  5. Moment Matching:Lee等人使用矩匹配方法来组合所有任务的网络权重。
    • Lee, S.W., Kim, J.H., Ha, J.W., Zhang, B.T.: Overcoming catastrophic forgetting by incremental moment matching In: NIPS (2017)
  6. 贝叶斯框架:Nguyen等人通过贝叶斯框架强制执行模型参数分布的接近性。
    • Nguyen, C.V., Li, Y., Bui, T.D., Turner, R.E.: Variational continual learning. ICLR
  7. Gradient Episodic Memory:Lopez-Paz和Ranzato更新梯度,以确保先前任务的损失不会增加。
    • Lopez-Paz, D., Ranzato, M.: Gradient episodic memory for continuum learning. In: NIPS (2017)
  8. Continual Learning with Deep Generative Replay:Shin等人使用学习到的生成模型来重新生成先前任务的样本进行再训练。
    • Lopez-Paz, D., Ranzato, M.: Gradient episodic memory for continuum learning. In: NIPS (2017)
  9. iCaRL(Incremental Classifier and Representation Learning):Rebuffi等人提出的方法,使用基于激活的正则化和最近邻分类器,以及存储先前任务的样本。
    • Rebuffi, S.V., Kolesnikov, A., Lampert, C.H.: iCaRL: Incremental classifier and representation learning. In: CVPR (2017)

4 算法

在这里插入图片描述
图中是参数重要性随优化轨迹累积的情况,参数重要性是一个反映模型参数对损失函数变化敏感度的指标。在RWalk中,这个指标是基于优化路径上参数变化对损失函数的影响来计算的。优化轨迹指的是在训练过程中,模型参数随着优化算法(如梯度下降)更新的路径。每次参数更新导致损失函数的变化都会被记录下来,并用于计算该参数的重要性分数。小的参数更新如果导致损失函数显著下降,则该参数被认为更重要。

  1. 定义问题设置
    • 将任务流定义为一系列标签集,每个任务对应一个数据集,数据集中包含输入和对应的真实标签。
  2. 评估设置
    • 区分单头(single-head)和多头(multi-head)评估。单头评估在测试时不知道任务标识符,而多头评估在测试时知道任务标识符。
  3. 知识定义
    • 定义知识为模型学习到的内容,可以基于网络的输入输出行为或网络参数。
  4. 引入评估指标
    • 提出两个新的评估指标:遗忘度量(Forgetting Measure)和固执度量(Intransigence Measure),用于评估模型在历史任务和当前任务上的性能。
  5. RWalk的组合正则化项
    • 提出EWC++算法,使用移动平均来更新Fisher信息矩阵,减少了内存需求和计算复杂度。
    • 提出修改的PI算法,使用近似的KL散度来衡量输出分布之间的距离,而不是像原始PI那样使用欧几里得空间中的损失变化。
    • 将基于EWC++算法的Fisher信息的参数重要性和PI算法的基于优化路径的重要性分数结合起来,即将两个正则化项结合起来,形成RWalk的目标函数。

$ \tilde{L}_k(\theta) = L_k(\theta) + \lambda \sum_{i=1}^{P} (F_{\theta_{k-1}}^i + s_{k-1}^{t_0}(\theta_i))(\theta_i - \theta_{k-1}^i)^2 $
其中, $ L_k(\theta) $是第k个任务的损失函数, $ F_{\theta{k-1}}^i $​是第i个参数的Fisher信息, $ s_{k-1}^{t_0}(\theta_i) $是从训练开始到第k-1个任务结束时累积的参数重要性分数, λ是正则化系数。

  1. 处理固执问题
  • 通过存储先前任务的一小部分代表性样本,帮助模型在单头评估中区分当前和先前的任务,从而减少固执。
  1. 采样策略
  • 研究不同的采样策略,包括均匀采样、基于平面距离的采样、基于熵的采样和基于特征均值(MoF)的采样,以选择代表性样本。
  1. 实验验证
  • 在MNIST和CIFAR-100数据集上进行实验,验证RWalk算法在准确性、遗忘和固执方面的表现。

5 实验分析

(1)在MNIST和CIFAR数据集上的比较结果

  • Vanilla:没有使用任何正则化,直接训练网络的方法。
  • EWC:使用Elastic Weight Consolidation方法的基线,它是一种参数重要性正则化方法。
  • PI:使用Path Integral方法的基线,也是一种参数重要性正则化方法。
  • iCaRL-hb1:使用incremental classifier and representation learning方法的混合版本1,它使用基于激活的正则化和最近邻分类器。
  • RWalk (Ours):本文提出的算法,结合了EWC++和修改版的PI。

image.png
在多头评估设置中,除了Vanilla方法外,所有方法都展现出了几乎零遗忘和固执,平均准确率非常高,这表明在这种设置下增量学习问题似乎已经得到了解决。
引入先前任务的代表性样本('-S’标记的方法)显著减少了固执问题,提高了模型在单头评估设置中的性能。
不同的方法对正则化系数λ的敏感度不同。例如,EWC和PI对λ非常敏感,而RWalk由于其正则化项的累积和平均机制,对λ的敏感度较低。
(2)用多头评估和单头评估时模型性能的变化
image.png

  • 在多任务评估中,所有方法的准确率都很高,表明在知道当前任务标识符的情况下,模型能够很好地学习并识别当前任务的样本。
  • 在单任务评估中,如果不使用先前任务的样本,模型的准确率会显著下降。这是因为模型需要在没有任何提示的情况下识别样本可能属于的任何先前学习过的任务,这对模型是一个更大的挑战。
  • 当在单任务评估中引入先前任务的代表性样本时,模型的准确率有所提高。这表明代表性样本有助于减少固执问题,帮助模型更好地学习和泛化先前任务的知识。
  • 随着任务的增加,如果不用样本,性能可能会下降,特别是对于Vanilla(未正则化的)模型。
  • 使用MoF采样策略的代表性样本可以显著改善模型在单任务评估设置下的性能,使模型更接近多任务评估的性能水平。

(3)增量学习中遗忘(Forgetting)和固执(Intransigence)之间的关系
image.png
图表被划分为四个象限,每个象限代表不同的遗忘和固执组合:

  • 正面遗忘和正面固执(PBT, PFT):模型在新任务上表现良好,但对旧任务的记忆有所下降。
  • 正面遗忘和负面固执(PBT, NFT):模型在新任务上表现良好,但对旧任务的记忆保持得更好。
  • 负面遗忘和正面固执(NBT, PFT):模型在旧任务上保持了较好的记忆,但在新任务上表现不佳。
  • 负面遗忘和负面固执(NBT, NFT):模型在旧任务和新任务上都保持了较好的性能。

理想情况下,模型位于图表的左下角,即低遗忘和低固执区域,这表示模型在学习新任务时能够很好地保留旧知识,并且对新任务有很好的适应性。

(4)不同模型在增量学习任务上的性能变化
image.png
在只有少量样本的情况下,RWalk算法也能够提供稳定的性能,这表明其正则化策略有效地减少了对大量样本的依赖。
Vanilla模型虽然最终能够通过增加样本数量提高性能,但这在实际应用中可能不可行,因为存储和处理大量样本资源成本高昂。
RWalk算法通过正则化减少了对历史样本的依赖,使得模型在连续学习任务中更加高效和有效。
(5)不同采样策略时的性能差异
image.png
均匀采样(Uniform Sampling)、基于平面距离的采样(Plane Distance-based Sampling, PD)、基于熵的采样(Entropy-based Sampling)和基于特征均值的采样(Mean of Features, MoF)。

  • MoF采样策略的优势:MoF采样策略在两个数据集上都优于其他采样策略,这表明基于特征空间中样本与类别均值的接近程度来选择样本是一种有效的策略。
  • 简单采样策略的表现:尽管均匀采样是一种非常简单的采样方法,但在某些情况下,它的表现与更复杂的MoF采样策略相当接近,特别是在数据集相对简单的情况下(如MNIST)。
  • 复杂数据集的挑战:在更复杂的数据集(如CIFAR-100)上,采样策略的选择对性能的影响更为显著,MoF采样策略的优势更加明显。
  • 不同模型的稳健性:某些正则化方法可能对采样策略的选择不太敏感,表明这些方法在不同采样策略下都能保持较好的性能。

6 思考

(1)什么是固执问题?
固执问题(Intransigence)指的是模型在学习新任务时难以更新其知识,从而导致对新信息的适应性差。具体来说,当一个模型已经学习了一系列任务后,它可能会变得难以对新的数据或任务进行有效的学习或调整,这种现象就是固执。
固执问题通常与以下现象相关:

  1. 新任务学习困难:模型对新任务的数据不够敏感,难以捕捉新任务的特征和模式。
  2. 性能停滞不前:即使在新任务上继续训练,模型的性能提升有限或者根本没有提升。
  3. 泛化能力下降:模型可能在新任务上表现不佳,无法很好地泛化到未见过的数据。
  4. 先前知识固化:模型可能过于依赖于先前学习的知识,而难以对新知识进行整合。

(2)什么是单头评估和多头评估?
单头评估(Single-head Evaluation)和多头评估(Multi-head Evaluation)是增量学习(Incremental Learning)中用来评估模型性能的两种不同的测试设置:

  1. 多头评估(Multi-head Evaluation):
    • 在多头评估中,模型在测试时知道当前任务的标识符。这意味着在进行预测时,模型只需要关注当前任务的相关类别,而不需要考虑之前学习过的所有类别。
    • 例如,如果模型已经学习了10个类别,并且当前任务是第5个类别,那么在多头评估中,模型在测试时只需要从第5个类别的样本中进行预测。
    • 这种设置相对容易,因为模型在测试时可以专注于当前任务的类别,不需要进行跨任务的区分。
  2. 单头评估(Single-head Evaluation):
    • 单头评估是一种更具挑战性的设置,因为在测试时模型不知道当前任务的标识符。这意味着模型需要从所有先前学习过的类别中进行预测。
    • 继续上述例子,如果模型在第5个任务上进行单头评估,它需要能够预测0到4和当前任务5的类别,即所有它之前见过的类别。
    • 这种设置模拟了现实世界中的场景,模型需要在没有任何任务提示的情况下识别新样本的类别,这对模型的泛化能力和记忆能力提出了更高的要求。

两种评估设置的主要区别在于模型在测试时是否知道当前任务的上下文:

  • 多头评估更简单,因为它只要求模型识别当前任务的类别。
  • 单头评估更复杂,因为它要求模型能够识别所有先前学习过的任务的类别。

在实际应用中,单头评估通常更接近现实情况,因为它要求模型具有更好的遗忘抑制能力和对新任务的适应性。这也是为什么在评估增量学习算法时,单头评估被认为是一个更加严格和实际的测试方法。
(3)MoF采样策略是什么?
MoF(Mean of Features,特征均值)采样策略是一种用于增量学习中的技术,旨在从先前任务的数据集中选择具有代表性的样本。对于每个类别,MoF策略计算在特征空间中所有样本的均值。这个均值向量代表了该类别在特征空间中的中心位置。然后,MoF策略选择那些在特征空间中接近其类别均值的样本。这些样本被认为更能代表其类别的特征。
(4)论文中的PI算法就是SI算法,性能上没有提升,但是时间效率上提升了。并解决了遗忘(Forgetting)和固执(Intransigence)的固执问题。另一个角度去评价连续学习算法的好坏,非常大的启发。说明不用去卷遗忘准确率。

目录
相关文章
|
1月前
|
存储 算法 安全
2024重生之回溯数据结构与算法系列学习之串(12)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丟脸好嘛?】
数据结构与算法系列学习之串的定义和基本操作、串的储存结构、基本操作的实现、朴素模式匹配算法、KMP算法等代码举例及图解说明;【含常见的报错问题及其对应的解决方法】你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
2024重生之回溯数据结构与算法系列学习之串(12)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丟脸好嘛?】
|
1月前
|
机器学习/深度学习 人工智能 自然语言处理
【EMNLP2024】基于多轮课程学习的大语言模型蒸馏算法 TAPIR
阿里云人工智能平台 PAI 与复旦大学王鹏教授团队合作,在自然语言处理顶级会议 EMNLP 2024 上发表论文《Distilling Instruction-following Abilities of Large Language Models with Task-aware Curriculum Planning》。
|
1月前
|
算法 安全 搜索推荐
2024重生之回溯数据结构与算法系列学习(8)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
数据结构王道第2.3章之IKUN和I原达人之数据结构与算法系列学习x单双链表精题详解、数据结构、C++、排序算法、java、动态规划你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
|
1月前
|
算法 安全 搜索推荐
2024重生之回溯数据结构与算法系列学习之单双链表精题详解(9)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
数据结构王道第2.3章之IKUN和I原达人之数据结构与算法系列学习x单双链表精题详解、数据结构、C++、排序算法、java、动态规划你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
|
1月前
|
存储 Web App开发 算法
2024重生之回溯数据结构与算法系列学习之单双链表【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
数据结构之单双链表按位、值查找;[前后]插入;删除指定节点;求表长、静态链表等代码及具体思路详解步骤;举例说明、注意点及常见报错问题所对应的解决方法
|
1月前
|
算法 安全 NoSQL
2024重生之回溯数据结构与算法系列学习之栈和队列精题汇总(10)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
数据结构王道第3章之IKUN和I原达人之数据结构与算法系列学习栈与队列精题详解、数据结构、C++、排序算法、java、动态规划你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
|
1月前
|
算法 安全 搜索推荐
2024重生之回溯数据结构与算法系列学习之王道第2.3章节之线性表精题汇总二(5)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
IKU达人之数据结构与算法系列学习×单双链表精题详解、数据结构、C++、排序算法、java 、动态规划 你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
|
1天前
|
机器学习/深度学习 算法
基于改进遗传优化的BP神经网络金融序列预测算法matlab仿真
本项目基于改进遗传优化的BP神经网络进行金融序列预测,使用MATLAB2022A实现。通过对比BP神经网络、遗传优化BP神经网络及改进遗传优化BP神经网络,展示了三者的误差和预测曲线差异。核心程序结合遗传算法(GA)与BP神经网络,利用GA优化BP网络的初始权重和阈值,提高预测精度。GA通过选择、交叉、变异操作迭代优化,防止局部收敛,增强模型对金融市场复杂性和不确定性的适应能力。
102 80
|
20天前
|
算法
基于WOA算法的SVDD参数寻优matlab仿真
该程序利用鲸鱼优化算法(WOA)对支持向量数据描述(SVDD)模型的参数进行优化,以提高数据分类的准确性。通过MATLAB2022A实现,展示了不同信噪比(SNR)下模型的分类误差。WOA通过模拟鲸鱼捕食行为,动态调整SVDD参数,如惩罚因子C和核函数参数γ,以寻找最优参数组合,增强模型的鲁棒性和泛化能力。
|
6天前
|
供应链 算法 调度
排队算法的matlab仿真,带GUI界面
该程序使用MATLAB 2022A版本实现排队算法的仿真,并带有GUI界面。程序支持单队列单服务台、单队列多服务台和多队列多服务台三种排队方式。核心函数`func_mms2`通过模拟到达时间和服务时间,计算阻塞率和利用率。排队论研究系统中顾客和服务台的交互行为,广泛应用于通信网络、生产调度和服务行业等领域,旨在优化系统性能,减少等待时间,提高资源利用率。