【博士每天一篇文献-算法】Overcoming catastrophic forgetting in neural networks

简介: 本文介绍了一种名为弹性权重合并(EWC)的方法,用于解决神经网络在学习新任务时遭受的灾难性遗忘问题,通过选择性地降低对旧任务重要权重的更新速度,成功地在多个任务上保持了高性能,且实验结果表明EWC在连续学习环境中的有效性。

阅读时间:2023-10-24

1 介绍

年份:2016
作者:James Kirkpatrick, Razvan Pascanu, Neil Rabinowitz, Joel Veness, Guillaume Desjardins, Andrei A. Rusu, Kieran Milan, John Quan, Tiago Ramalho, Agnieszka Grabska-Barwinska, Demis Hassabis, Claudia Clopath, Dharshan Kumaran, Raia Hadsell,加州史丹佛大學史丹佛大學
期刊:Proceedings of the national academy of sciences
引用量:5449
这篇论文的主题是关于神经网络如何克服灾难性遗忘的问题,灾难性遗忘是神经网络在顺序学习任务时的一个限制。论文提出了一种称为弹性权重合并(EWC)的方法,可以使神经网络在学习新任务的同时记住旧任务。EWC会有选择地降低对先前学习任务重要的权重的学习速度,从而防止灾难性遗忘。作者通过在MNIST数据集上解决分类任务和顺序学习Atari 2600游戏的实验来证明EWC的有效性。论文将EWC与其他方法如L2正则化和dropout正则化进行了比较,结果表明EWC在保持旧任务高性能的同时学习新任务方面优于这些方法。论文解释了EWC的实现和合理性,包括如何约束重要参数和确定哪些权重对于每个任务是重要的。论文还讨论了哺乳动物大脑可能支持无灾难遗忘连续学习的神经机制。总的来说,这篇论文通过使用EWC提出了解决神经网络灾难性遗忘问题的方法。

2 创新点

  1. EWC方法:论文提出了一种名为弹性权重整合的算法,用于实现神经网络的连续学习。该算法根据先前学习任务中权重的重要性,减缓学习过程,从而保留旧任务的知识。
  2. 在MNIST数据集和Atari 2600游戏中的应用:论文通过在MNIST数据集上进行分类任务和在Atari 2600游戏中进行学习来展示EWC的有效性。结果表明,相比于L2正则化和dropout正则化等其他方法,EWC在学习新任务的同时能够维持旧任务的高性能。
  3. EWC的实施和正当性:论文解释了EWC的具体实施和合理性,包括对重要参数的约束和确定每个任务中哪些权重是重要的。论文还提到了哺乳动物大脑中支持连续学习而不发生灾难性遗忘的神经机制。

3 算法

(1)计算步骤

  • 计算每个权重在先前任务中的重要性:
    • 先前任务的损失函数:L_prev(θ),其中θ表示网络的权重。
    • Fisher信息矩阵:F_prev(θ) = E[∇²L_prev(θ)],其中∇²表示梯度的二阶导数。
    • 权重重要性:I_prev(θ) = F_prev(θ) * (θ - θ_prev)²,其中θ_prev表示在先前任务上训练后的权重。
  • 计算当前任务的损失函数:当前任务的损失函数:L_curr(θ)。
  • 计算正则化项并更新网络权重:
    • 正则化项:EWC_loss(θ) = L_curr(θ) + λ * Σ[I_prev(θ)], 其中λ是正则化项的权重,Σ表示对所有权重求和。
    • 更新网络权重:θ_new = argmin(θ)[EWC_loss(θ)]

(2)推理过程
训练了一个模型,其参数为 θ \theta θ,定义最小化以下损失函数来完成此操作:
$ \mathcal{L}(\theta) = \mathcal{L}_{new}(\theta) + \sum_{i=1}^{n} \frac{\lambda}{2} F_i (\theta_i - \theta_i^*)^2 $

其中, $ \mathcal{L}_{new}(\theta) $是新任务的损失函数,n 是先前任务的数量,$ F_i $​ 是 Fisher 信息矩阵的对角线元素, $ \theta_i^* $∗​ 是在先前任务i中找到的最优参数。 λ 是一个超参数,控制先前任务对新任务的影响。
Fisher 信息矩阵是 Hessian 矩阵的期望值,它衡量了损失函数对参数的二阶导数。 在 EWC 中,只计算对角线元素,因为它们提供了最大的信息,同时也更容易计算。Fisher 信息矩阵的对角线元素可以通过以下公式计算:
$$ F_{i,j} = \mathbb{E}_{x\sim D_i}[\frac{\partial \log p(y|x,\theta)}{\partial \theta_i} \frac{\partial \log p(y|x,\theta)}{\partial \theta_j}] $$

其中, Di​是先前任务i的数据分布, p(y∣x,θ)是模型在给定输入x 和参数 $ \theta $的情况下预测输出y的概率分布。
在每次学习新任务之前,需要计算 Fisher 信息矩阵和最优参数θi∗​。这可以通过在先前任务上运行梯度下降来实现,直到收敛为止。一旦计算出 Fisher 信息矩阵和最优参数,就可以使用 EWC 来学习新任务,同时保留先前任务的知识。
最后,可以使用以下公式计算 EWC 梯度:
$$ g_i = \nabla_{\theta\_i} \mathcal{L}_{new}(\theta) + \lambda \sum_{j=1}^{n} F_{i,j} (\theta_i - \theta_i^*) $$
其中,gi​是 EWC 梯度, $ \nabla_{\theta_i} \mathcal{L}_{new}(\theta) $是新任务的梯度。通过添加正则化项,EWC 可以确保新任务不会完全覆盖先前任务的知识,从而在连续学习中实现知识共享。

5 实验结果分析

(1)总结一
image.png

  • 使用纯随机梯度下降(SGD)训练这个任务序列会引发灾难性遗忘。
  • 图2A展示了两个不同任务的测试集性能。在训练从第一个任务切换到第二个任务时,任务B的性能迅速下降,而任务A的性能迅速上升。
  • 任务A的遗忘问题会随着更长的训练时间而进一步恶化。
  • 使用L2正则化不能解决这个问题,因为它对所有权重施加了相同的保护限制,导致在任务B上学习的能力受到限制。
  • 然而,使用EWC可以根据任务A中每个权重的重要性,使网络能够在不遗忘任务A的情况下很好地学习任务B。
  • 图2B展示了使用EWC和使用SGD与dropout正则化的所有任务的平均性能。可以看到EWC在旧任务上保持了高性能,并且仍然能够学习新任务。
  • 图2C展示了两个不同置换程度下网络深度的Fisher信息矩阵的相似性。任务越不相似,早期层的Fisher信息矩阵重叠越小。

(2)总结二

  • 当网络在两个非常相似的任务上训练(两个MNIST版本,只有少数像素被重排),这两个任务在整个网络中依赖于相似的权重集
  • 当两个任务之间更不相似时,网络开始为两个任务分配单独的能力(即权重)。

在进行大量重排时,网络靠近输出的层确实被两个任务重复使用。这反映了重排使得输入对内容是非常不同的,但输出的内容(即类别标签)是共享的。
(3)总结三
EWC可以在要求更高的强化学习(RL)领域中支持连续学习。作者测试了在经典的Atari 2600游戏集上,将Deep Q Networks与EWC相结合的方法。实验中,通过使用EWC,能够学习多个游戏,而不会忘记以前学习的游戏 。与以前的RL方法相比,EWC利用了固定资源(即网络容量)的单个网络,并且计算开销较小。

6 代码

https://github.com/yashkant/Elastic-Weight-Consolidation

目录
相关文章
|
6月前
|
人工智能 算法 安全
【博士论文】基于局部中心量度的聚类算法研究(Matlab代码实现)
【博士论文】基于局部中心量度的聚类算法研究(Matlab代码实现)
208 0
|
机器学习/深度学习 人工智能 资源调度
【博士每天一篇文献-算法】连续学习算法之HAT: Overcoming catastrophic forgetting with hard attention to the task
本文介绍了一种名为Hard Attention to the Task (HAT)的连续学习算法,通过学习几乎二值的注意力向量来克服灾难性遗忘问题,同时不影响当前任务的学习,并通过实验验证了其在减少遗忘方面的有效性。
435 12
|
机器学习/深度学习 算法 计算机视觉
【博士每天一篇文献-算法】持续学习经典算法之LwF: Learning without forgetting
LwF(Learning without Forgetting)是一种机器学习方法,通过知识蒸馏损失来在训练新任务时保留旧任务的知识,无需旧任务数据,有效解决了神经网络学习新任务时可能发生的灾难性遗忘问题。
1332 9
|
存储 机器学习/深度学习 算法
【博士每天一篇文献-算法】连续学习算法之HNet:Continual learning with hypernetworks
本文提出了一种基于任务条件超网络(Hypernetworks)的持续学习模型,通过超网络生成目标网络权重并结合正则化技术减少灾难性遗忘,实现有效的任务顺序学习与长期记忆保持。
356 4
|
存储 机器学习/深度学习 算法
【博士每天一篇文献-算法】连续学习算法之RWalk:Riemannian Walk for Incremental Learning Understanding
RWalk算法是一种增量学习框架,通过结合EWC++和修改版的Path Integral算法,并采用不同的采样策略存储先前任务的代表性子集,以量化和平衡遗忘和固执,实现在学习新任务的同时保留旧任务的知识。
443 3
|
4月前
|
机器学习/深度学习 算法 机器人
【水下图像增强融合算法】基于融合的水下图像与视频增强研究(Matlab代码实现)
【水下图像增强融合算法】基于融合的水下图像与视频增强研究(Matlab代码实现)
436 0
|
4月前
|
数据采集 分布式计算 并行计算
mRMR算法实现特征选择-MATLAB
mRMR算法实现特征选择-MATLAB
298 2
|
5月前
|
传感器 机器学习/深度学习 编解码
MATLAB|主动噪声和振动控制算法——对较大的次级路径变化具有鲁棒性
MATLAB|主动噪声和振动控制算法——对较大的次级路径变化具有鲁棒性
285 3
|
4月前
|
机器学习/深度学习 算法 机器人
使用哈里斯角Harris和SIFT算法来实现局部特征匹配(Matlab代码实现)
使用哈里斯角Harris和SIFT算法来实现局部特征匹配(Matlab代码实现)
232 8
|
4月前
|
机器学习/深度学习 算法 自动驾驶
基于导向滤波的暗通道去雾算法在灰度与彩色图像可见度复原中的研究(Matlab代码实现)
基于导向滤波的暗通道去雾算法在灰度与彩色图像可见度复原中的研究(Matlab代码实现)
254 8

热门文章

最新文章