【博士每天一篇文献-算法】Memory aware synapses_ Learning what (not) to forget

简介: 本文介绍了一种名为“记忆感知突触”(Memory Aware Synapses, MAS)的终身学习方法,该方法通过无监督在线评估神经网络参数的重要性,并在新任务学习时对重要参数的更改进行惩罚,有效防止了旧任务知识的覆盖,实现了内存效率和性能提升,同时具有灵活性和通用性。

阅读时间:2023-12-13

1 介绍

年份:2018
作者:Rahaf Aljundi,丰田汽车欧洲公司研究员;阿卜杜拉国王科技大学(KAUST)助理教授;Marcus Rohrbach德国达姆施塔特工业大学多模式可靠人工智能教授
会议: Proceedings of the European conference on computer vision (ECCV)
引用量:1416
代码:https://github.com/wannabeOG/MAS-PyTorch
https://github.com/rahafaljundi/MAS-Memory-Aware-Synapses
Aljundi R, Babiloni F, Elhoseiny M, et al. Memory aware synapses: Learning what (not) to forget[C]//Proceedings of the European conference on computer vision (ECCV). 2018: 139-154.

image.png
鉴于模型容量有限而新信息无限,知识需要被选择性地保留或抹去。提出了一种新的方法,称为“记忆感知突触”(Memory Aware Synapses, MAS),该算法不仅以在线方式计算网络参数的重要性,而且以无监督的方式适应网络测试的数据。当学习新任务时,对重要参数的更改可以受到惩罚,有效防止与以前任务相关的知识被覆盖。构建了一个能够适应权重重要性的持续系统,以系统需要记住的内容。我们的方法需要恒定的内存量,并具有我们上面列出的主要期望的终身学习特性,同时实现了最先进的性能。

2 创新点

  1. 记忆感知突触(MAS)方法:提出了一种新的终身学习方法,能够选择性地保留或抹去知识,以适应不断变化的学习任务和有限的模型容量。
  2. 无监督在线参数重要性评估:MAS能够在没有标签数据的情况下,在线地评估神经网络参数的重要性,这一点与传统的依赖于损失函数的方法不同。
  3. 基于输出函数敏感性的权重调整:MAS通过评估输出函数对参数变化的敏感性来计算参数的重要性,而不是依赖于损失函数的梯度,这避免了在损失函数局部最小值处梯度接近零的问题。
  4. 与Hebb学习规则的联系:展示了MAS方法与Hebb学习规则之间的联系,这是一种解释突触可塑性的生物学理论,表明MAS具有生物学上的合理性。
  5. 适应性权重更新:MAS能够根据未标记的测试数据更新参数的重要性权重,使得模型能够适应特定的测试条件和上下文。
  6. 实验验证:在多个任务和数据集上进行了实验验证,包括对象识别任务和学习预测<主体,谓语,对象>三元组的任务,证明了MAS方法的有效性。
  7. 性能提升:在标准终身学习设置和特定测试条件下,MAS都展现出了优于现有技术的性能,尤其是在减少灾难性遗忘方面。
  8. 内存效率:MAS方法在保持性能的同时,具有较低的内存消耗,这对于资源受限的应用场景尤为重要。
  9. 灵活性和通用性:MAS不仅限于特定的任务或数据类型,能够广泛应用于各种终身学习的场景。

3 相关研究

3.1 基于数据的方法

  1. Aljundi R, Chakravarty P, Tuytelaars T. Expert gate: Lifelong learning with a network of experts[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2017: 3366-3375.
  2. Li Z, Hoiem D. Learning without forgetting[J]. IEEE transactions on pattern analysis and machine intelligence, 2017, 40(12): 2935-2947.
  3. Rannen A, Aljundi R, Blaschko M B, et al. Encoder based lifelong learning[C]//Proceedings of the IEEE international conference on computer vision. 2017: 1320-1328.
  4. Shmelkov K, Schmid C, Alahari K. Incremental learning of object detectors without catastrophic forgetting[C]//Proceedings of the IEEE international conference on computer vision. 2017: 3400-3409.

3.2 基于模型的方法

  1. Fernando C, Banarse D, Blundell C, et al. Pathnet: Evolution channels gradient descent in super neural networks[J]. arXiv preprint arXiv:1701.08734, 2017.
  2. Lee S W, Kim J H, Jun J, et al. Overcoming catastrophic forgetting by incremental moment matching[J]. Advances in neural information processing systems, 2017, 30.
  3. Zenke F, Poole B, Ganguli S. Continual learning through synaptic intelligence[C]//International conference on machine learning. PMLR, 2017: 3987-3995.(和本文相似)
  4. Kirkpatrick J, Pascanu R, Rabinowitz N, et al. Overcoming catastrophic forgetting in neural networks[J]. Proceedings of the national academy of sciences, 2017, 114(13): 3521-3526.(和本文相似)

但是这些方法有缺点:
(1)EWC基于Fisher信息矩阵对角线的近似来估计参数的重要性,这可能不完全准确反映参数的真实重要性。
(2)EWC为每个先前任务使用单独的惩罚项,这在实际应用中可能计算量大且不可行,因此需要对惩罚项进行简化。
(3)SI在新任务训练期间以在线方式估计重要性权重,依赖于批量梯度下降中的权重变化,这可能会高估权重的重要性。
(4)SI算法当从预训练网络开始学习时,一些权重可能在使用中没有大的变化,导致它们的重要性被低估。
(5)SI在训练过程中计算重要性权重,并在训练结束后固定这些权重,这限制了模型对测试数据的适应性。

4 算法

(1)MAS全局版本

  1. 参数重要性估计
  • 对于数据点 $ x_k $​,计算网络输出函数 $ F(x_k; \theta) $对参数 $ \theta_{ij} $ 的梯度 $ g_{ij}(x_k) = \frac{\partial F(x_k; \theta)}{\partial \theta_{ij}} $。
  • 为了简化计算,使用平方 ℓ2 范数的梯度:
    $$ g_{ij}(x_k) = \frac{\partial [|| F(x\_k; \theta) ||_2^2]}{\partial \theta_{ij}} $$
  1. 累积梯度以计算重要性权重
  • 使用公式计算参数 $ \theta_{ij} $ 的重要性权重:

$$ \Omega_{ij} = \frac{1}{N} \sum_{k=1}^{N} || g_{ij}(x_k) || $$
4. 学习新任务时的正则化

  • 新任务损失函数 L(θ) 包括正则化项,正则化项事是惩罚了对之前任务重要参数 $ \theta_{ij}^* $​的改变。
    $$L(\theta) = L_n(\theta) + \lambda \sum_{i,j} \Omega_{ij} (\theta_{ij} - \theta_{ij}^*)2 $$
    其中 λ 是正则化系数, $ \theta_{ij}^* $ 是先前任务中确定的“旧”网络参数。
  1. 更新重要性权重
  • 训练新任务后,根据之前计算的 Ω 更新重要性矩阵 Ω。

(2)MAS局部版本
局部版本的MAS方法不是考虑整个网络学习到的函数F,而是将其分解为一系列对应于网络每层的函数Fl。通过局部地保留每层给定其输入的输出,可以保留全局函数F。
其中参数重要性可以通过神经元激活的相关性来衡量:
$$\Omega{ij} = \frac{1}{N} \sum_{k=1}^{N} y_{i}^k \cdot y_{j}^k $$
对于 ReLU 激活函数,简化为:
$$ g_{ij}(x_k) = 2 \cdot y_{i}^k \cdot y_{j}^k $$
其中 $ y_{i}^k $​ 和 $ y_{j}^k $​分别是输入 $ x_k $​对应的第i个和第j个神经元的激活值。
(3优缺点:
全局MAS方法
优点:

  1. 全面性:全局MAS考虑整个网络学习到的函数,从而评估参数对整体性能的影响,这有助于捕捉不同层级之间的相互作用。
  2. 精确性:通过计算整个网络输出的梯度,全局MAS可以更精确地评估参数的重要性。
  3. 适应性:能够适应不同的数据分布,因为它是基于网络最终输出的敏感度来评估参数的重要性。
  4. 通用性:适用于任何类型的数据和任务,因为它不依赖于特定层级的激活模式。

缺点:

  1. 计算成本:可能需要更多的计算资源,因为它需要对整个网络的输出函数进行梯度计算。
  2. 复杂性:实现起来可能比局部MAS更复杂,因为它涉及到整个网络的梯度传播。

局部MAS方法(基于Hebb理论)
优点:

  1. 计算效率:局部MAS通过仅考虑单层的激活来计算参数的重要性,这减少了计算量。
  2. 实现简单:由于其简单性,局部MAS更容易实现和集成到现有的神经网络架构中。
  3. 快速适应:可以快速适应新任务或数据,因为它只需要局部的激活信息。
  4. 与生物学习机制的联系:局部MAS与Hebb学习规则有直接联系,这为理解人工神经网络中的学习过程提供了生物学上的见解。

缺点:

  1. 可能的不准确性:由于它只考虑局部信息,可能会忽略不同层级间参数的相互作用,导致对参数重要性的估计不够准确。
  2. 过度依赖局部激活:如果局部激活不能很好地代表整个网络的行为,那么局部MAS可能无法正确评估参数的重要性。
  3. 特定任务的局限性:可能在某些任务或数据分布上表现不如全局MAS,特别是当任务需要跨层级的信息整合时。

5 实验分析

(1)对比的模型
本文中提到的对比模型包括以下几种:

  1. Finetuning (FineTune): 这是一种基线方法,当学习新任务时,对网络的参数进行微调,以适应新任务的数据。
  2. Learning without Forgetting (LwF): 该方法在面对新任务时,通过记录先前任务的输出概率,并在新的损失函数中使用这些概率作为目标,以减少对旧知识的遗忘。
  3. Encoder Based Lifelong Learning (EBLL): EBLL在LwF的基础上,为每个任务学习一个浅层编码器,并应用变化惩罚和知识蒸馏损失来减少对先前任务的遗忘。
  4. Incremental Moment Matching (IMM): IMM通过对共享参数的变化施加L2惩罚来学习新任务,并在序列结束时通过第一或第二矩匹配合并模型。
  5. Elastic Weight Consolidation (EWC): EWC是首个提出在新任务学习时使用正则化网络参数的方法,它使用Fisher信息矩阵的对角线作为重要性度量。
  6. Synaptic Intelligence (SI): SI在训练新任务时,以在线方式估计重要性权重,并在训练后期任务时对先前任务的重要参数变化进行惩罚。
  7. Memory Aware Synapses (MAS): 本文提出的新方法,它通过无监督和在线的方式计算神经网络参数的重要性,基于预测输出函数对参数变化的敏感度。

(2)分类任务分析
实验使用基于三个数据集的两任务序列:MIT Scenes(室内场景分类)、Caltech-UCSD Birds(细粒度鸟类分类)和Oxford Flowers(细粒度花卉分类)。将MAS方法与其他几种终身学习(LLL)方法进行了比较,包括Finetune(微调)、Learning without Forgetting (LwF)、Encoder Based Lifelong Learning (EBLL)、Incremental Moment Matching (IMM)、Elastic Weight Consolidation (EWC)和Synaptic Intelligence (SI)。
image.png
每个任务的分类准确率。从图中可以看出,Finetune基线方法在新任务上性能较好,但在旧任务上性能下降显著,这表明了灾难性遗忘的问题。相比之下,MAS方法在所有任务上都显示出较高的准确率,并且与其他终身学习方法相比,其性能下降非常小。
image.png
为性能下降情况。FineTune方法在旧任务上的性能下降非常严重,这再次证实了其在连续学习中的不足。其他方法如LwF、EBLL、IMM、EWC和SI都显示出一定程度的遗忘,但遗忘程度较Finetune有所减轻。MAS方法在所有任务上的遗忘率最低,显示出最小的性能下降,这表明其在终身学习环境中对旧知识的保留效果最好。
(3)内存容量要求
image.png

  • MAS (Memory Aware Synapses):本文提出的方法,它在每个学习步骤中的内存需求是所有方法中最低的,这表明MAS在处理遗忘问题时非常内存高效。
  • SIEWC、**LwF **、EBLLIMM的内存需求随着任务序列的进行而逐渐增加。特别是IMM方法,其内存需求随任务数量线性增长,因为它需要存储所有任务的模型。

(4)敏感度分析
image.png
在一系列经过排列的MNIST任务中,平均性能和平均遗忘率随超参数λ的变化情况。

  1. 超参数λ的影响:λ是全局MAS方法中用于权衡新任务学习和旧任务遗忘的正则化项的权重。通过改变λ的值,可以观察到模型在新任务学习性能和旧任务遗忘之间的权衡。
  2. 性能与遗忘的平衡:从图中可以看出,当λ的值增加时,模型倾向于更多地保留旧任务的知识,从而减少遗忘。然而,如果λ过大,可能会对新任务的学习造成负面影响,因为模型过于保守,不愿意对参数进行足够的更新。

(5)检索任务性能分析
在6DS数据集的体育子集上,经过4个任务序列学习后,每种方法的平均精度均值(Mean Average Precision, MAP)的变化情况。MAP是信息检索和计算机视觉领域常用的性能指标,用于衡量模型对于检索任务的准确性。6DS数据集,全称为"Six Domains Dataset",是一个用于事实学习(Fact Learning)的中等规模数据集。它专门设计用于支持图像中事实的学习和检索任务,例如理解图像中的对象、属性和它们之间的关系。6DS数据集通常包含多种类别的图像,并且每个图像都与一个或多个事实相关联,这些事实以三元组的形式表示,包括主体(Subject)、谓语(Predicate)和对象(Object)。
image.png
与其他方法相比,MAS在体育子集上的MAP值下降较少,表明其在面对新任务时,能够更有效地保留先前任务的知识,减少灾难性遗忘。

6 思考

(1)MAS算法缺点

  1. 超参数调整:MAS算法引入了一个新的超参数λ,用于平衡新任务学习和旧任务遗忘之间的权衡。确定合适的λ值可能需要额外的调整工作,这可能在实际应用中增加复杂性。
  2. 计算资源:尽管MAS算法在内存效率方面有所优化,但在计算参数重要性时,尤其是在使用全局版本时,可能需要较多的计算资源,尤其是当数据集很大时。
  3. 适应性:MAS算法能够根据未标记的测试数据自适应地调整参数重要性,但在某些情况下,这种自适应性可能不如预期,特别是如果测试条件与训练条件差异很大时。
  4. 局部与全局方法的权衡:论文中提到了MAS的局部版本(基于Hebb理论)和全局版本。局部版本计算更快,但可能在准确性上有所折衷。选择使用哪种版本可能取决于具体应用的需求。
  5. 特定任务的泛化能力:MAS算法在论文中的任务上表现良好,但其泛化能力到其他类型的问题或任务上可能会有所不同,这需要在更广泛的任务和数据集上进行验证。
  6. 灾难性遗忘问题:尽管MAS算法旨在减少灾难性遗忘,但在学习大量新任务或非常不同的任务时,仍然可能面临知识遗忘的挑战。
  7. 实际应用的复杂性:在实际应用中,可能需要进一步的调整和优化,以适应特定的数据分布、任务特性或计算约束。
  8. 理论基础与实际效果的验证:MAS算法虽然在理论上受到Hebb学习规则的启发,但其在真实世界数据和任务上的有效性需要通过更多的实验来验证。
  9. 长期维护和更新:在长期学习和连续任务中,MAS算法可能需要持续的维护和更新,以适应不断变化的数据和任务需求。
  10. 对特定数据类型的依赖:MAS算法可能对某些类型的数据更加敏感,例如,如果数据具有特定的分布特性或噪声模式,算法的效果可能会受到影响。

(2)如何计算使用平方 ℓ2 范数的梯度,为什么要这么计算来表示参数的重要性?

  1. 简化计算:如果直接使用多维输出的每个维度来计算梯度,就需要对每个输出维度执行一次反向传播,这将需要与输出维度数量相同次数的计算。而使用平方 ℓ 2 \ell_2 ℓ2​范数,可以得到一个标量值,这意味着只需要一次反向传播,从而简化了计算过程。
  2. 避免梯度接近零的问题:在某些情况下,如果模型已经收敛到局部最小值,那么基于损失函数的梯度可能会非常小,这会导致参数重要性的估计不准确。使用输出函数的敏感度而不是损失函数的梯度可以避免这个问题,因为输出函数不太可能处于局部最小值。
  3. 与Hebb学习规则的联系:在文中提到的局部MAS方法中,使用平方 ℓ 2 \ell_2 ℓ2​ 范数的梯度与Hebb学习理论相联系。Hebb理论指出,如果两个神经元的激活同时发生,它们之间的突触连接应该被加强。在人工神经网络中,这可以被解释为如果两个神经元的激活值高度相关,那么它们之间的连接权重就更重要。平方 ℓ 2 \ell_2 ℓ2​范数的梯度可以反映这种相关性。
  4. 参数重要性的准确估计:通过平方 ℓ 2 \ell_2 ℓ2​范数的梯度,可以更准确地衡量参数对输出的影响。如果一个参数的小变化导致输出的平方 ℓ 2 \ell_2 ℓ2​范数有较大变化,那么这个参数对模型的预测就非常重要。
  5. 无监督和在线适应性:使用平方 ℓ 2 \ell_2 ℓ2​范数的梯度允许模型在无监督的情况下在线地适应新数据。这意味着模型可以在没有标签数据的情况下,根据输入数据动态调整参数的重要性权重。
  6. 提高效率:相比于基于损失的权重变化,使用平方 ℓ 2 \ell_2 ℓ2​范数的梯度可以更高效地估计参数的重要性,因为它只需要一次反向传播,并且可以利用所有可用的数据点来更新权重的重要性,而不需要额外的存储或处理。
目录
相关文章
|
4月前
|
机器学习/深度学习 人工智能 资源调度
【博士每天一篇文献-算法】连续学习算法之HAT: Overcoming catastrophic forgetting with hard attention to the task
本文介绍了一种名为Hard Attention to the Task (HAT)的连续学习算法,通过学习几乎二值的注意力向量来克服灾难性遗忘问题,同时不影响当前任务的学习,并通过实验验证了其在减少遗忘方面的有效性。
87 12
|
4月前
|
机器学习/深度学习 算法 计算机视觉
【博士每天一篇文献-算法】持续学习经典算法之LwF: Learning without forgetting
LwF(Learning without Forgetting)是一种机器学习方法,通过知识蒸馏损失来在训练新任务时保留旧任务的知识,无需旧任务数据,有效解决了神经网络学习新任务时可能发生的灾难性遗忘问题。
282 9
|
4月前
|
存储 机器学习/深度学习 算法
【博士每天一篇文献-算法】连续学习算法之RWalk:Riemannian Walk for Incremental Learning Understanding
RWalk算法是一种增量学习框架,通过结合EWC++和修改版的Path Integral算法,并采用不同的采样策略存储先前任务的代表性子集,以量化和平衡遗忘和固执,实现在学习新任务的同时保留旧任务的知识。
99 3
|
18天前
|
算法
基于WOA算法的SVDD参数寻优matlab仿真
该程序利用鲸鱼优化算法(WOA)对支持向量数据描述(SVDD)模型的参数进行优化,以提高数据分类的准确性。通过MATLAB2022A实现,展示了不同信噪比(SNR)下模型的分类误差。WOA通过模拟鲸鱼捕食行为,动态调整SVDD参数,如惩罚因子C和核函数参数γ,以寻找最优参数组合,增强模型的鲁棒性和泛化能力。
|
24天前
|
机器学习/深度学习 算法 Serverless
基于WOA-SVM的乳腺癌数据分类识别算法matlab仿真,对比BP神经网络和SVM
本项目利用鲸鱼优化算法(WOA)优化支持向量机(SVM)参数,针对乳腺癌早期诊断问题,通过MATLAB 2022a实现。核心代码包括参数初始化、目标函数计算、位置更新等步骤,并附有详细中文注释及操作视频。实验结果显示,WOA-SVM在提高分类精度和泛化能力方面表现出色,为乳腺癌的早期诊断提供了有效的技术支持。
|
4天前
|
供应链 算法 调度
排队算法的matlab仿真,带GUI界面
该程序使用MATLAB 2022A版本实现排队算法的仿真,并带有GUI界面。程序支持单队列单服务台、单队列多服务台和多队列多服务台三种排队方式。核心函数`func_mms2`通过模拟到达时间和服务时间,计算阻塞率和利用率。排队论研究系统中顾客和服务台的交互行为,广泛应用于通信网络、生产调度和服务行业等领域,旨在优化系统性能,减少等待时间,提高资源利用率。
|
11天前
|
存储 算法
基于HMM隐马尔可夫模型的金融数据预测算法matlab仿真
本项目基于HMM模型实现金融数据预测,包括模型训练与预测两部分。在MATLAB2022A上运行,通过计算状态转移和观测概率预测未来值,并绘制了预测值、真实值及预测误差的对比图。HMM模型适用于金融市场的时间序列分析,能够有效捕捉隐藏状态及其转换规律,为金融预测提供有力工具。
|
20天前
|
算法
基于GA遗传算法的PID控制器参数优化matlab建模与仿真
本项目基于遗传算法(GA)优化PID控制器参数,通过空间状态方程构建控制对象,自定义GA的选择、交叉、变异过程,以提高PID控制性能。与使用通用GA工具箱相比,此方法更灵活、针对性强。MATLAB2022A环境下测试,展示了GA优化前后PID控制效果的显著差异。核心代码实现了遗传算法的迭代优化过程,最终通过适应度函数评估并选择了最优PID参数,显著提升了系统响应速度和稳定性。
|
11天前
|
机器学习/深度学习 算法 信息无障碍
基于GoogleNet深度学习网络的手语识别算法matlab仿真
本项目展示了基于GoogleNet的深度学习手语识别算法,使用Matlab2022a实现。通过卷积神经网络(CNN)识别手语手势,如&quot;How are you&quot;、&quot;I am fine&quot;、&quot;I love you&quot;等。核心在于Inception模块,通过多尺度处理和1x1卷积减少计算量,提高效率。项目附带完整代码及操作视频。
|
17天前
|
算法
基于WOA鲸鱼优化的购售电收益与风险评估算法matlab仿真
本研究提出了一种基于鲸鱼优化算法(WOA)的购售电收益与风险评估算法。通过将售电公司购售电收益风险计算公式作为WOA的目标函数,经过迭代优化计算出最优购电策略。实验结果表明,在迭代次数超过10次后,风险价值收益优化值达到1715.1万元的最大值。WOA还确定了中长期市场、现货市场及可再生能源等不同市场的最优购电量,验证了算法的有效性。核心程序使用MATLAB2022a实现,通过多次迭代优化,实现了售电公司收益最大化和风险最小化的目标。