【博士每天一篇文献-算法】Memory aware synapses_ Learning what (not) to forget

简介: 本文介绍了一种名为“记忆感知突触”(Memory Aware Synapses, MAS)的终身学习方法,该方法通过无监督在线评估神经网络参数的重要性,并在新任务学习时对重要参数的更改进行惩罚,有效防止了旧任务知识的覆盖,实现了内存效率和性能提升,同时具有灵活性和通用性。

阅读时间:2023-12-13

1 介绍

年份:2018
作者:Rahaf Aljundi,丰田汽车欧洲公司研究员;阿卜杜拉国王科技大学(KAUST)助理教授;Marcus Rohrbach德国达姆施塔特工业大学多模式可靠人工智能教授
会议: Proceedings of the European conference on computer vision (ECCV)
引用量:1416
代码:https://github.com/wannabeOG/MAS-PyTorch
https://github.com/rahafaljundi/MAS-Memory-Aware-Synapses
Aljundi R, Babiloni F, Elhoseiny M, et al. Memory aware synapses: Learning what (not) to forget[C]//Proceedings of the European conference on computer vision (ECCV). 2018: 139-154.

image.png
鉴于模型容量有限而新信息无限,知识需要被选择性地保留或抹去。提出了一种新的方法,称为“记忆感知突触”(Memory Aware Synapses, MAS),该算法不仅以在线方式计算网络参数的重要性,而且以无监督的方式适应网络测试的数据。当学习新任务时,对重要参数的更改可以受到惩罚,有效防止与以前任务相关的知识被覆盖。构建了一个能够适应权重重要性的持续系统,以系统需要记住的内容。我们的方法需要恒定的内存量,并具有我们上面列出的主要期望的终身学习特性,同时实现了最先进的性能。

2 创新点

  1. 记忆感知突触(MAS)方法:提出了一种新的终身学习方法,能够选择性地保留或抹去知识,以适应不断变化的学习任务和有限的模型容量。
  2. 无监督在线参数重要性评估:MAS能够在没有标签数据的情况下,在线地评估神经网络参数的重要性,这一点与传统的依赖于损失函数的方法不同。
  3. 基于输出函数敏感性的权重调整:MAS通过评估输出函数对参数变化的敏感性来计算参数的重要性,而不是依赖于损失函数的梯度,这避免了在损失函数局部最小值处梯度接近零的问题。
  4. 与Hebb学习规则的联系:展示了MAS方法与Hebb学习规则之间的联系,这是一种解释突触可塑性的生物学理论,表明MAS具有生物学上的合理性。
  5. 适应性权重更新:MAS能够根据未标记的测试数据更新参数的重要性权重,使得模型能够适应特定的测试条件和上下文。
  6. 实验验证:在多个任务和数据集上进行了实验验证,包括对象识别任务和学习预测<主体,谓语,对象>三元组的任务,证明了MAS方法的有效性。
  7. 性能提升:在标准终身学习设置和特定测试条件下,MAS都展现出了优于现有技术的性能,尤其是在减少灾难性遗忘方面。
  8. 内存效率:MAS方法在保持性能的同时,具有较低的内存消耗,这对于资源受限的应用场景尤为重要。
  9. 灵活性和通用性:MAS不仅限于特定的任务或数据类型,能够广泛应用于各种终身学习的场景。

3 相关研究

3.1 基于数据的方法

  1. Aljundi R, Chakravarty P, Tuytelaars T. Expert gate: Lifelong learning with a network of experts[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2017: 3366-3375.
  2. Li Z, Hoiem D. Learning without forgetting[J]. IEEE transactions on pattern analysis and machine intelligence, 2017, 40(12): 2935-2947.
  3. Rannen A, Aljundi R, Blaschko M B, et al. Encoder based lifelong learning[C]//Proceedings of the IEEE international conference on computer vision. 2017: 1320-1328.
  4. Shmelkov K, Schmid C, Alahari K. Incremental learning of object detectors without catastrophic forgetting[C]//Proceedings of the IEEE international conference on computer vision. 2017: 3400-3409.

3.2 基于模型的方法

  1. Fernando C, Banarse D, Blundell C, et al. Pathnet: Evolution channels gradient descent in super neural networks[J]. arXiv preprint arXiv:1701.08734, 2017.
  2. Lee S W, Kim J H, Jun J, et al. Overcoming catastrophic forgetting by incremental moment matching[J]. Advances in neural information processing systems, 2017, 30.
  3. Zenke F, Poole B, Ganguli S. Continual learning through synaptic intelligence[C]//International conference on machine learning. PMLR, 2017: 3987-3995.(和本文相似)
  4. Kirkpatrick J, Pascanu R, Rabinowitz N, et al. Overcoming catastrophic forgetting in neural networks[J]. Proceedings of the national academy of sciences, 2017, 114(13): 3521-3526.(和本文相似)

但是这些方法有缺点:
(1)EWC基于Fisher信息矩阵对角线的近似来估计参数的重要性,这可能不完全准确反映参数的真实重要性。
(2)EWC为每个先前任务使用单独的惩罚项,这在实际应用中可能计算量大且不可行,因此需要对惩罚项进行简化。
(3)SI在新任务训练期间以在线方式估计重要性权重,依赖于批量梯度下降中的权重变化,这可能会高估权重的重要性。
(4)SI算法当从预训练网络开始学习时,一些权重可能在使用中没有大的变化,导致它们的重要性被低估。
(5)SI在训练过程中计算重要性权重,并在训练结束后固定这些权重,这限制了模型对测试数据的适应性。

4 算法

(1)MAS全局版本

  1. 参数重要性估计
  • 对于数据点 $ x_k $​,计算网络输出函数 $ F(x_k; \theta) $对参数 $ \theta_{ij} $ 的梯度 $ g_{ij}(x_k) = \frac{\partial F(x_k; \theta)}{\partial \theta_{ij}} $。
  • 为了简化计算,使用平方 ℓ2 范数的梯度:
    $$ g_{ij}(x_k) = \frac{\partial [|| F(x\_k; \theta) ||_2^2]}{\partial \theta_{ij}} $$
  1. 累积梯度以计算重要性权重
  • 使用公式计算参数 $ \theta_{ij} $ 的重要性权重:

$$ \Omega_{ij} = \frac{1}{N} \sum_{k=1}^{N} || g_{ij}(x_k) || $$
4. 学习新任务时的正则化

  • 新任务损失函数 L(θ) 包括正则化项,正则化项事是惩罚了对之前任务重要参数 $ \theta_{ij}^* $​的改变。
    $$L(\theta) = L_n(\theta) + \lambda \sum_{i,j} \Omega_{ij} (\theta_{ij} - \theta_{ij}^*)2 $$
    其中 λ 是正则化系数, $ \theta_{ij}^* $ 是先前任务中确定的“旧”网络参数。
  1. 更新重要性权重
  • 训练新任务后,根据之前计算的 Ω 更新重要性矩阵 Ω。

(2)MAS局部版本
局部版本的MAS方法不是考虑整个网络学习到的函数F,而是将其分解为一系列对应于网络每层的函数Fl。通过局部地保留每层给定其输入的输出,可以保留全局函数F。
其中参数重要性可以通过神经元激活的相关性来衡量:
$$\Omega{ij} = \frac{1}{N} \sum_{k=1}^{N} y_{i}^k \cdot y_{j}^k $$
对于 ReLU 激活函数,简化为:
$$ g_{ij}(x_k) = 2 \cdot y_{i}^k \cdot y_{j}^k $$
其中 $ y_{i}^k $​ 和 $ y_{j}^k $​分别是输入 $ x_k $​对应的第i个和第j个神经元的激活值。
(3优缺点:
全局MAS方法
优点:

  1. 全面性:全局MAS考虑整个网络学习到的函数,从而评估参数对整体性能的影响,这有助于捕捉不同层级之间的相互作用。
  2. 精确性:通过计算整个网络输出的梯度,全局MAS可以更精确地评估参数的重要性。
  3. 适应性:能够适应不同的数据分布,因为它是基于网络最终输出的敏感度来评估参数的重要性。
  4. 通用性:适用于任何类型的数据和任务,因为它不依赖于特定层级的激活模式。

缺点:

  1. 计算成本:可能需要更多的计算资源,因为它需要对整个网络的输出函数进行梯度计算。
  2. 复杂性:实现起来可能比局部MAS更复杂,因为它涉及到整个网络的梯度传播。

局部MAS方法(基于Hebb理论)
优点:

  1. 计算效率:局部MAS通过仅考虑单层的激活来计算参数的重要性,这减少了计算量。
  2. 实现简单:由于其简单性,局部MAS更容易实现和集成到现有的神经网络架构中。
  3. 快速适应:可以快速适应新任务或数据,因为它只需要局部的激活信息。
  4. 与生物学习机制的联系:局部MAS与Hebb学习规则有直接联系,这为理解人工神经网络中的学习过程提供了生物学上的见解。

缺点:

  1. 可能的不准确性:由于它只考虑局部信息,可能会忽略不同层级间参数的相互作用,导致对参数重要性的估计不够准确。
  2. 过度依赖局部激活:如果局部激活不能很好地代表整个网络的行为,那么局部MAS可能无法正确评估参数的重要性。
  3. 特定任务的局限性:可能在某些任务或数据分布上表现不如全局MAS,特别是当任务需要跨层级的信息整合时。

5 实验分析

(1)对比的模型
本文中提到的对比模型包括以下几种:

  1. Finetuning (FineTune): 这是一种基线方法,当学习新任务时,对网络的参数进行微调,以适应新任务的数据。
  2. Learning without Forgetting (LwF): 该方法在面对新任务时,通过记录先前任务的输出概率,并在新的损失函数中使用这些概率作为目标,以减少对旧知识的遗忘。
  3. Encoder Based Lifelong Learning (EBLL): EBLL在LwF的基础上,为每个任务学习一个浅层编码器,并应用变化惩罚和知识蒸馏损失来减少对先前任务的遗忘。
  4. Incremental Moment Matching (IMM): IMM通过对共享参数的变化施加L2惩罚来学习新任务,并在序列结束时通过第一或第二矩匹配合并模型。
  5. Elastic Weight Consolidation (EWC): EWC是首个提出在新任务学习时使用正则化网络参数的方法,它使用Fisher信息矩阵的对角线作为重要性度量。
  6. Synaptic Intelligence (SI): SI在训练新任务时,以在线方式估计重要性权重,并在训练后期任务时对先前任务的重要参数变化进行惩罚。
  7. Memory Aware Synapses (MAS): 本文提出的新方法,它通过无监督和在线的方式计算神经网络参数的重要性,基于预测输出函数对参数变化的敏感度。

(2)分类任务分析
实验使用基于三个数据集的两任务序列:MIT Scenes(室内场景分类)、Caltech-UCSD Birds(细粒度鸟类分类)和Oxford Flowers(细粒度花卉分类)。将MAS方法与其他几种终身学习(LLL)方法进行了比较,包括Finetune(微调)、Learning without Forgetting (LwF)、Encoder Based Lifelong Learning (EBLL)、Incremental Moment Matching (IMM)、Elastic Weight Consolidation (EWC)和Synaptic Intelligence (SI)。
image.png
每个任务的分类准确率。从图中可以看出,Finetune基线方法在新任务上性能较好,但在旧任务上性能下降显著,这表明了灾难性遗忘的问题。相比之下,MAS方法在所有任务上都显示出较高的准确率,并且与其他终身学习方法相比,其性能下降非常小。
image.png
为性能下降情况。FineTune方法在旧任务上的性能下降非常严重,这再次证实了其在连续学习中的不足。其他方法如LwF、EBLL、IMM、EWC和SI都显示出一定程度的遗忘,但遗忘程度较Finetune有所减轻。MAS方法在所有任务上的遗忘率最低,显示出最小的性能下降,这表明其在终身学习环境中对旧知识的保留效果最好。
(3)内存容量要求
image.png

  • MAS (Memory Aware Synapses):本文提出的方法,它在每个学习步骤中的内存需求是所有方法中最低的,这表明MAS在处理遗忘问题时非常内存高效。
  • SIEWC、**LwF **、EBLLIMM的内存需求随着任务序列的进行而逐渐增加。特别是IMM方法,其内存需求随任务数量线性增长,因为它需要存储所有任务的模型。

(4)敏感度分析
image.png
在一系列经过排列的MNIST任务中,平均性能和平均遗忘率随超参数λ的变化情况。

  1. 超参数λ的影响:λ是全局MAS方法中用于权衡新任务学习和旧任务遗忘的正则化项的权重。通过改变λ的值,可以观察到模型在新任务学习性能和旧任务遗忘之间的权衡。
  2. 性能与遗忘的平衡:从图中可以看出,当λ的值增加时,模型倾向于更多地保留旧任务的知识,从而减少遗忘。然而,如果λ过大,可能会对新任务的学习造成负面影响,因为模型过于保守,不愿意对参数进行足够的更新。

(5)检索任务性能分析
在6DS数据集的体育子集上,经过4个任务序列学习后,每种方法的平均精度均值(Mean Average Precision, MAP)的变化情况。MAP是信息检索和计算机视觉领域常用的性能指标,用于衡量模型对于检索任务的准确性。6DS数据集,全称为"Six Domains Dataset",是一个用于事实学习(Fact Learning)的中等规模数据集。它专门设计用于支持图像中事实的学习和检索任务,例如理解图像中的对象、属性和它们之间的关系。6DS数据集通常包含多种类别的图像,并且每个图像都与一个或多个事实相关联,这些事实以三元组的形式表示,包括主体(Subject)、谓语(Predicate)和对象(Object)。
image.png
与其他方法相比,MAS在体育子集上的MAP值下降较少,表明其在面对新任务时,能够更有效地保留先前任务的知识,减少灾难性遗忘。

6 思考

(1)MAS算法缺点

  1. 超参数调整:MAS算法引入了一个新的超参数λ,用于平衡新任务学习和旧任务遗忘之间的权衡。确定合适的λ值可能需要额外的调整工作,这可能在实际应用中增加复杂性。
  2. 计算资源:尽管MAS算法在内存效率方面有所优化,但在计算参数重要性时,尤其是在使用全局版本时,可能需要较多的计算资源,尤其是当数据集很大时。
  3. 适应性:MAS算法能够根据未标记的测试数据自适应地调整参数重要性,但在某些情况下,这种自适应性可能不如预期,特别是如果测试条件与训练条件差异很大时。
  4. 局部与全局方法的权衡:论文中提到了MAS的局部版本(基于Hebb理论)和全局版本。局部版本计算更快,但可能在准确性上有所折衷。选择使用哪种版本可能取决于具体应用的需求。
  5. 特定任务的泛化能力:MAS算法在论文中的任务上表现良好,但其泛化能力到其他类型的问题或任务上可能会有所不同,这需要在更广泛的任务和数据集上进行验证。
  6. 灾难性遗忘问题:尽管MAS算法旨在减少灾难性遗忘,但在学习大量新任务或非常不同的任务时,仍然可能面临知识遗忘的挑战。
  7. 实际应用的复杂性:在实际应用中,可能需要进一步的调整和优化,以适应特定的数据分布、任务特性或计算约束。
  8. 理论基础与实际效果的验证:MAS算法虽然在理论上受到Hebb学习规则的启发,但其在真实世界数据和任务上的有效性需要通过更多的实验来验证。
  9. 长期维护和更新:在长期学习和连续任务中,MAS算法可能需要持续的维护和更新,以适应不断变化的数据和任务需求。
  10. 对特定数据类型的依赖:MAS算法可能对某些类型的数据更加敏感,例如,如果数据具有特定的分布特性或噪声模式,算法的效果可能会受到影响。

(2)如何计算使用平方 ℓ2 范数的梯度,为什么要这么计算来表示参数的重要性?

  1. 简化计算:如果直接使用多维输出的每个维度来计算梯度,就需要对每个输出维度执行一次反向传播,这将需要与输出维度数量相同次数的计算。而使用平方 ℓ 2 \ell_2 ℓ2​范数,可以得到一个标量值,这意味着只需要一次反向传播,从而简化了计算过程。
  2. 避免梯度接近零的问题:在某些情况下,如果模型已经收敛到局部最小值,那么基于损失函数的梯度可能会非常小,这会导致参数重要性的估计不准确。使用输出函数的敏感度而不是损失函数的梯度可以避免这个问题,因为输出函数不太可能处于局部最小值。
  3. 与Hebb学习规则的联系:在文中提到的局部MAS方法中,使用平方 ℓ 2 \ell_2 ℓ2​ 范数的梯度与Hebb学习理论相联系。Hebb理论指出,如果两个神经元的激活同时发生,它们之间的突触连接应该被加强。在人工神经网络中,这可以被解释为如果两个神经元的激活值高度相关,那么它们之间的连接权重就更重要。平方 ℓ 2 \ell_2 ℓ2​范数的梯度可以反映这种相关性。
  4. 参数重要性的准确估计:通过平方 ℓ 2 \ell_2 ℓ2​范数的梯度,可以更准确地衡量参数对输出的影响。如果一个参数的小变化导致输出的平方 ℓ 2 \ell_2 ℓ2​范数有较大变化,那么这个参数对模型的预测就非常重要。
  5. 无监督和在线适应性:使用平方 ℓ 2 \ell_2 ℓ2​范数的梯度允许模型在无监督的情况下在线地适应新数据。这意味着模型可以在没有标签数据的情况下,根据输入数据动态调整参数的重要性权重。
  6. 提高效率:相比于基于损失的权重变化,使用平方 ℓ 2 \ell_2 ℓ2​范数的梯度可以更高效地估计参数的重要性,因为它只需要一次反向传播,并且可以利用所有可用的数据点来更新权重的重要性,而不需要额外的存储或处理。
目录
相关文章
|
1月前
|
机器学习/深度学习 人工智能 资源调度
【博士每天一篇文献-算法】连续学习算法之HAT: Overcoming catastrophic forgetting with hard attention to the task
本文介绍了一种名为Hard Attention to the Task (HAT)的连续学习算法,通过学习几乎二值的注意力向量来克服灾难性遗忘问题,同时不影响当前任务的学习,并通过实验验证了其在减少遗忘方面的有效性。
39 12
|
1月前
|
机器学习/深度学习 算法 计算机视觉
【博士每天一篇文献-算法】持续学习经典算法之LwF: Learning without forgetting
LwF(Learning without Forgetting)是一种机器学习方法,通过知识蒸馏损失来在训练新任务时保留旧任务的知识,无需旧任务数据,有效解决了神经网络学习新任务时可能发生的灾难性遗忘问题。
65 9
|
1月前
|
存储 机器学习/深度学习 算法
【博士每天一篇文献-算法】连续学习算法之RWalk:Riemannian Walk for Incremental Learning Understanding
RWalk算法是一种增量学习框架,通过结合EWC++和修改版的Path Integral算法,并采用不同的采样策略存储先前任务的代表性子集,以量化和平衡遗忘和固执,实现在学习新任务的同时保留旧任务的知识。
66 3
|
6天前
|
算法 BI Serverless
基于鱼群算法的散热片形状优化matlab仿真
本研究利用浴盆曲线模拟空隙外形,并通过鱼群算法(FSA)优化浴盆曲线参数,以获得最佳孔隙度值及对应的R值。FSA通过模拟鱼群的聚群、避障和觅食行为,实现高效全局搜索。具体步骤包括初始化鱼群、计算适应度值、更新位置及判断终止条件。最终确定散热片的最佳形状参数。仿真结果显示该方法能显著提高优化效率。相关代码使用MATLAB 2022a实现。
|
6天前
|
算法 数据可视化
基于SSA奇异谱分析算法的时间序列趋势线提取matlab仿真
奇异谱分析(SSA)是一种基于奇异值分解(SVD)和轨迹矩阵的非线性、非参数时间序列分析方法,适用于提取趋势、周期性和噪声成分。本项目使用MATLAB 2022a版本实现从强干扰序列中提取趋势线,并通过可视化展示了原时间序列与提取的趋势分量。代码实现了滑动窗口下的奇异值分解和分组重构,适用于非线性和非平稳时间序列分析。此方法在气候变化、金融市场和生物医学信号处理等领域有广泛应用。
|
29天前
|
算法
基于模糊控制算法的倒立摆控制系统matlab仿真
本项目构建了一个基于模糊控制算法的倒立摆控制系统,利用MATLAB 2022a实现了从不稳定到稳定状态的转变,并输出了相应的动画和收敛过程。模糊控制器通过对小车位置与摆的角度误差及其变化量进行模糊化处理,依据预设的模糊规则库进行模糊推理并最终去模糊化为精确的控制量,成功地使倒立摆维持在直立位置。该方法无需精确数学模型,适用于处理系统的非线性和不确定性。
基于模糊控制算法的倒立摆控制系统matlab仿真
|
7天前
|
资源调度 算法
基于迭代扩展卡尔曼滤波算法的倒立摆控制系统matlab仿真
本课题研究基于迭代扩展卡尔曼滤波算法的倒立摆控制系统,并对比UKF、EKF、迭代UKF和迭代EKF的控制效果。倒立摆作为典型的非线性系统,适用于评估不同滤波方法的性能。UKF采用无迹变换逼近非线性函数,避免了EKF中的截断误差;EKF则通过泰勒级数展开近似非线性函数;迭代EKF和迭代UKF通过多次迭代提高状态估计精度。系统使用MATLAB 2022a进行仿真和分析,结果显示UKF和迭代UKF在非线性强的系统中表现更佳,但计算复杂度较高;EKF和迭代EKF则更适合维数较高或计算受限的场景。
|
9天前
|
算法
基于SIR模型的疫情发展趋势预测算法matlab仿真
该程序基于SIR模型预测疫情发展趋势,通过MATLAB 2022a版实现病例增长拟合分析,比较疫情防控力度。使用SIR微分方程模型拟合疫情发展过程,优化参数并求解微分方程组以预测易感者(S)、感染者(I)和移除者(R)的数量变化。![]该模型将总人群分为S、I、R三部分,通过解析或数值求解微分方程组预测疫情趋势。
|
9天前
|
算法 数据可视化 数据安全/隐私保护
基于LK光流提取算法的图像序列晃动程度计算matlab仿真
该算法基于Lucas-Kanade光流方法,用于计算图像序列的晃动程度。通过计算相邻帧间的光流场并定义晃动程度指标(如RMS),可量化图像晃动。此版本适用于Matlab 2022a,提供详细中文注释与操作视频。完整代码无水印。
|
24天前
|
算法 数据安全/隐私保护
基于LS算法的OFDM+QPSK系统信道估计均衡matlab性能仿真
基于MATLAB 2022a的仿真展示了OFDM+QPSK系统中最小二乘(LS)算法的信道估计与均衡效果。OFDM利用多个低速率子载波提高频谱效率,通过循环前缀克服多径衰落。LS算法依据导频符号估计信道参数,进而设计均衡器以恢复数据符号。核心程序实现了OFDM信号处理流程,包括加性高斯白噪声的加入、保护间隔去除、快速傅立叶变换及信道估计与均衡等步骤,并最终计算误码率,验证了算法的有效性。
43 2