【博士每天一篇文献-算法】连续学习算法之HAT: Overcoming catastrophic forgetting with hard attention to the task

简介: 本文介绍了一种名为Hard Attention to the Task (HAT)的连续学习算法,通过学习几乎二值的注意力向量来克服灾难性遗忘问题,同时不影响当前任务的学习,并通过实验验证了其在减少遗忘方面的有效性。

阅读时间:2023-12-28

1 介绍

年份:2018
作者:Joan Serrà,Sony AI;Dídac Surís,PhD student, Columbia University;Marius Miron,Earth Species Project
会议:International conference on machine learning. PMLR
引用量:1064
代码:https://github.com/joansj/hat
Serra J, Suris D, Miron M, et al. Overcoming catastrophic forgetting with hard attention to the task[C]//International conference on machine learning. PMLR, 2018: 4548-4557.
本文提出了一种基于任务的硬注意力机制(Hard Attention to the Task, HAT),通过学习几乎二值的注意力向量来保持先前任务的信息,同时不影响当前任务的学习。通过门控任务嵌入学习硬注意力掩码,这些掩码定义了网络权重的更新约束。利用先前任务的注意力向量来调整梯度,保护先前任务中重要的权重。通过调整超参数,控制学习知识的稳定性和紧凑性,使得方法适用于不同的应用场景。实验结果显示,HAT在不同的实验设置中都能显著降低遗忘率。HAT机制还提供了监控网络行为的能力,例如评估网络容量使用情况和权重重用情况。此外,还可以利用硬注意力掩码进行网络剪枝,实现模型压缩。
image.png
image.png
image.png

2 创新点

  1. 硬注意力机制:提出了一种硬注意力机制,通过学习几乎二值的注意力向量,同时对每个任务进行学习,以保持先前任务的信息,而不会影响当前任务的学习。
  2. 任务嵌入的门控:通过门控任务嵌入,使用反向传播和批量随机梯度下降(SGD),动态创建和销毁跨层的路径,这些路径在学习新任务时可以被保留。
  3. 累积注意力向量:通过累积先前任务的注意力向量来调整梯度,从而保持对先前任务重要的权重的更新,防止对这些权重进行大的更新。
  4. 压缩性和稳定性控制:通过引入超参数,控制学习知识的稳定性和紧凑性,使得该方法也适用于在线学习或网络压缩应用。
  5. 无需记忆模块:与需要使用记忆模块的方法不同,HAT不需要记忆模块,而是通过学习单元级别的掩码来约束网络权重的更新。
  6. 简化的超参数:该方法只有两个超参数,直观地控制了学习知识的稳定性和紧凑性,而且调整这些超参数对于获得良好性能并不关键。
  7. 监控能力:硬注意力机制提供了监控模型行为的性,例如通过计算条件掩码来评估哪些权重获得高注意力值,以及跨任务的权重重用情况。
  8. 网络剪枝:利用硬注意力掩码来评估网络权重的重要性,并剪枝最不重要的权重,从而实现网络的压缩。

3 相关研究

  1. 复习策略(Rehearsal):
    • 论文 (1995). Catastrophic forgetting, rehearsal and pseudorehearsal.
  2. 记忆模块(Memory Modules):
    • 论文: (2017). iCaRL: incremental classifier and representation learning.
  3. 伪复习(Pseudorehearsal):
    • 论文:(1995). Catastrophic forgetting, rehearsal and pseudorehearsal.
  4. 生成网络替代记忆模块(Generative Networks as Memory Modules):
    • 论文:(2017). A strategy for an uncompromising incremental learner.
  5. 减少表示重叠(Reducing Representational Overlap):
    • 论文:(1991). Using semi-distributed representations to overcome catastrophic forgetting in connectionist networks.
  6. 结构正则化(Structural Regularization):
    • 论文:(2017). Improved multitask learning through synaptic intelligence
  7. 增量时刻匹配(Incremental Moment Matching, IMM):
    • 论文:(2017)Overcoming catastrophic forgetting by incremental moment matching.
  8. 弹性权重巩固(Elastic Weight Consolidation, EWC):
    • 论文: (2017)Overcoming catastrophic forgetting in neural networks
  9. 路径网络(PathNet):
    • 论文:PathNet: evolution channels gradient descent in super neural networks.
  10. 渐进式神经网络(Progressive Neural Networks, PNNs):
  • 论文:(2016)Progressive neural networks.
  1. 突触智能 (Synaptic Intelligence, SI):
  • 论文:(2017). Improved multitask learning through synaptic intelligence.
  1. 打包网络 (PackNet):
  • 论文:(2017). PackNet: adding multiple tasks to a single network by iterative pruning
  1. 动态可扩展网络 (Dynamically Expandable Networks, DEN):
  • 论文: (2018). Lifelong learning with dynamically expandable networks.

4 算法

4.1 关键点

(1)前向传播
数据从输入层逐层传递,每层的输出都会与任务相关的注意力向量相乘。注意力向量 $ a_{l}^t $​是通过 $ \sigma (e_{l}^t \cdot s) $计算的,其中 σ 是Sigmoid函数, $ e_l^t $​是任务嵌入,通常使用高斯分布_N_(0,1) 进行初始化,s 是正的缩放参数。在学习新任务时,利用所有先前任务的注意力向量来计算累积注意力向量,以保持对先前任务重要的权重。
使用Sigmoid函数和正的缩放参数的目的是来构建一个伪步函数,允许梯度流动。训练过程中逐渐增加缩放参数,使得注意力向量在测试时接近二值。
注意力向量 $ a_l^t $​ 与每层的输出 $ h_l $​进行元素级乘法操作,生成新的层输出 $ h_l=a_l^t\odot h_l $
(2)反向传播
在训练过程中,梯度会根据累积的注意力向量进行调整,以保护对先前任务重要的权重。梯度调整是通过最小化当前层和前一层的累积注意力向量的元素来实现。
为了解决在训练过程中嵌入梯度变化不大的问题,引入了嵌入梯度补偿机制。通过除以退火Sigmoid函数的导数并乘以期望的补偿量,调整梯度。
并且在损失函数中添加正则化项来促进注意力向量的稀疏性,从而为未来任务保留模型容量。
(3)补充
本文提出可以利用硬注意力机制监控模型的行为,例如评估网络容量的使用情况和权重的重用情况。此外,还可以使用硬注意力掩码来评估网络权重的重要性,并进行剪枝,从而实现模型压缩。

2.2 算法步骤

image.png

  1. 初始化
    • 在训练开始之前,任务嵌入 $ e_l^t $​需要被初始化。通常使用高斯分布N(0, 1) 进行初始化。
  2. 前向传播
    • 在每次训练迭代中,输入数据通过网络的每层进行前向传播,生成每层的输出 $ h_l $​。
  3. 计算门控任务嵌入
    • 对于每层l和每个任务t,计算门控任务嵌入注意力向量 $ a_l^t $​。这通常通过一个门控函数(如Sigmoid函数)和缩放参数s来实现:
      $ a_l^t = \sigma(s \cdot e_l^t) $
    • 其中 σ是Sigmoid函数,s是一个正的缩放参数。
  4. 元素级乘法
    • 将注意力向量 $ a_l^t $​与每层的输出 $ h_l $​进行元素级乘法操作,生成新的层输出 $ h_l = a_l^t \odot h_l $。
  5. 反向传播
    • 在损失函数计算后,通过反向传播算法计算每层的梯度。对于任务嵌入 $ e_l^t $​的梯度计算如下:
      $ \frac{\partial L}{\partial e_l^t} = \frac{\partial L}{\partial h_l} \odot a_l^t \odot (1 - a_l^t) ∂elt​$
    • 其中L是损失函数, ⊙表示元素级乘法。
  6. 梯度补偿
    • 为了增强梯度并促进 $ e_l^t $​的学习,应用梯度补偿。梯度补偿的公式如下:
      $ q_ l^t = s{\text{max}} \cdot (\cosh(e_lt) + 1) \cdot \frac{\partial L}{\partial e_l^t} $
    • 这里 $ s_{\text{max}} $​是一个超参数,用于控制梯度补偿的强度。
  7. 更新任务嵌入
    • 使用梯度下降算法更新任务嵌入 $ e_l^t $​:
      $ e_l^t \leftarrow e_l^t - \eta \cdot q_l^t $
    • 其中 η是学习率。
  8. 退火参数s
    • 在训练过程中,逐渐增加缩放参数s 退火,使得 $ a_l^t $​逐渐接近二值化,从而在测试阶段更稳定:
      $ s = \frac{1}{s_{\text{max}} + \left( \frac{s_{\text{max}} - 1}{s_{\text{max}}} \right) \cdot \left( \frac{b}{B} \right)} $
    • 其中b是当前批次索引,B是总批次数。

5 实验分析

(1)退火参数 s 对于任务嵌入 $ e_l^t $​ 的梯度影响
image.png
Compensated这条曲线展示了应用了梯度补偿后的梯度分布。梯度补偿的目的是调整梯度的分布,使得它们在 $ e_l^t $​的期望活动范围内有更大的梯度值。
算法梯度更新中,通过除以退火Sigmoid函数的导数并乘以期望的补偿量,HAT算法调整梯度的分布和幅度,使得 $ e_l^t $​的梯度在训练过程中能够更有效地更新。
(2)平均遗忘比率 ρ≤t 随任务数 t的变化情况
image.png
该比率越接近0,表示模型遗忘旧任务的程度越低。HAT方法在所有任务数t$\ge$2 的情况下,遗忘比率始终低于其他基线方法。这表明HAT方法在减少灾难性遗忘方面非常有效。随着任务数的增加,HAT方法的遗忘比率增加幅度较小,显示出较好的稳定性和鲁棒性。

(3)平均遗忘比率
image.png
在学习第二个任务( ρ ≤ 2)和最后一个任务( ρ ≤ 8)之后的平均遗忘比率。这些数据是通过10次运行实验得出的,括号内是标准差。

  1. HAT方法的优势
    • HAT(硬注意力到任务)方法在学习第二个任务后的平均遗忘比率 ρ≤ 2 为 -0.02,是所有方法中最低的。这表明HAT在初始任务转换时遗忘最少。
    • 在学习最后一个任务后,HAT的 ρ ≤ 8 为 -0.06,同样是所有方法中最低的。这表明即使在多次任务转换后,HAT也能保持较低的遗忘率。
  2. 基线方法的表现
    • 其他基线方法如LFL(Less-Forgetting Learning)、LWF(Learning Without Forgetting)、SGD(随机梯度下降)等在学习第二个任务后遗忘比率较高,特别是LFL和LWF,其 ρ ≤ 2 分别为 -0.73 和 -0.14。
    • 在学习最后一个任务后,这些方法的遗忘比率进一步增加,例如LFL的 ρ ≤ 8为 -0.92,LWF的 ρ ≤ 8为 -0.80。
  3. 遗忘比率的减少
    • HAT方法在 ρ ≤ 2和 ρ ≤ 8 上的低遗忘比率表明其在学习新任务时能够有效地保留旧任务的知识。
  4. 标准差的稳定性
    • HAT方法的标准差较低,表明其在不同实验运行中表现稳定,不受随机因素的影响。
  5. 其他方法的敏感性
    • 一些方法如LFL和PathNet显示出较高的标准差,这表明它们对超参数、初始化或数据集的变化较为敏感。

(4)超参数 $ s_{max} $​和c对平均遗忘比率 ρ ≤ 8 的影响。

image.png

  • $ s_{max} $​:控制门控函数的极性或“硬度”。较高的 $ s_{max} $​值会使门控函数更接近单位步函数,从而在训练过程中更稳定地保留先前任务的信息。
  • c:控制模型的紧凑性。较高的 c 值会促使模型学习更紧凑的表示,减少对当前任务的容量使用。
  • 随着 $x s_{max} $​ 值的增加,平均遗忘比率 ρ ≤ 8 通常会降低。这表明更高的 $ s_{max} $​有助于更好地保留先前任务的知识。然而, $ s_{max} $​ 过高会限制模型对新任务的适应能力,因此需要找到一个平衡点。
  • c 值的增加会导致 ρ ≤ 8 增加,这表明更高的 c值会使模型在学习新任务时遗忘更多的旧任务知识。c 值较低时,模型会使用更多的网络容量来学习当前任务,从而减少遗忘。

(5)顺序任务学习时,网络容量的使用情况
image.png

  • 曲线显示了在顺序学习多个任务时,网络如何动态地调整其容量的使用。这表明HAT算法能够根据每个任务的需求调整网络的活跃部分。
  • 虚线垂直线表示任务切换点,可以看到在这些点附近,网络容量的使用率会发生变化。这反映了模型在适应新任务时对网络容量的重新分配。
  • 在某些任务中,网络容量的使用率会降低,这表明HAT算法在学习新任务时能够保留一部分网络容量用于未来任务。这种保留有助于减少灾难性遗忘。
  • HAT算法不仅能够处理单个任务,还能够在多个任务之间灵活地调整网络容量的使用,这使得模型在顺序学习环境中具有更好的可扩展性和灵活性。

(6)顺序学习多个任务时,网络权重的重用率百分比
image.png
表格中的每个单元格显示了从任务 ti​到任务 tj​权重重用的百分比。

  • 表格中的数值越高,表示从先前任务到当前任务的权重重用率越高。这表明模型能够将先前任务学到的知识应用到新任务中,减少了从头开始学习的需求。
  • 不同任务之间的权重重用率会有显著差异。这与任务之间的相似性有关。例如,相似的任务(如MNIST和NotMNIST)会有更高的权重重用率。
  • 通过分析权重重用率,可以评估模型在不同任务之间的适应性和灵活性。高权重重用率表明模型能够更好地利用已有知识,而低权重重用率需要模型进行更多的调整以适应新任务。

(7)验证集准确率 A1 与压缩百分比之间的关系
image.png
压缩百分比指的是网络中被移除的权重或单元的比例。每个点代表一个训练周期(Epoch)后,对应压缩百分比下的验证集准确率。三角形表示使用标准随机梯度下降(SGD)方法训练的模型的准确率,作为不进行压缩时的参考。

  • 图中显示了随着压缩百分比的增加,验证集准确率的变化。一般来说,随着压缩百分比的增加,准确率会有所下降,因为更多的权重被移除。
  • 尽管随着压缩百分比的增加,准确率会降低,但图中的点仍然显示出在较高压缩率下仍然可以获得相对较高的准确率。这表明通过适当的压缩,模型可以在减少参数数量的同时保持较好的性能。
  • 三角形表示不进行压缩时SGD方法的准确率。通过比较点和三角形,可以评估压缩对模型性能的影响。如果压缩后的准确率接近或仅略低于不压缩的准确率,这表明压缩是有效的。
  • 在实际应用中,需要在模型大小和准确率之间找到平衡。图中的点可以帮助确定最佳的压缩率,以在减少模型大小的同时最小化准确率的损失。

6 思考

(1)计算平均遗忘率?

  1. 定义遗忘率
    • 对于每个任务 $ \tau $,在顺序学习了任务t ) 之后,计算其测试集上的准确率 Aτ≤t​。
    • 同时定义一个随机分类器的准确率 $ A_{\tau R} $​,它使用任务 τ的类别信息进行随机分类。
    • 还定义一个多任务学习下的准确率 $ A_{\tau \leq t J} $​,即在同时学习t个任务的情况下,任务 τ的准确率。
  2. 计算单个任务的遗忘率
    • 使用以下公式计算每个任务 τ在学习了任务t之后的遗忘率:
      $ \rho_{\tau \leq t} = \frac{A_{\tau \leq t} - A_{\tau R}}{A_{\tau \leq t J} - A_{\tau R} - 1} $
    • 其中:
      • $ A_{\tau \leq t} $是在顺序学习了任务 t之后,任务 τ的准确率。
      • $ A_{\tau R} $​是随机分类器的准确率。
      • $ A_{\tau \leq t J} $​是同时学习t个任务的情况下,任务 τ的准确率。
  3. 计算平均遗忘率
    • 对于每个任务 τ,从 τ = 1 到 τ = t ,计算 $ \rho_{\tau \leq t} $​。
    • 然后取这些遗忘率的平均值,得到平均遗忘率:
      $ \rho_{\leq t} = \frac{1}{t} \sum_{\tau=1}^{t} \rho_{\tau \leq t} $

这个平均遗忘率 ρ≤t​反映了在顺序学习了t个任务之后,模型相对于随机分类器和多任务学习情况的平均遗忘程度。值越接近0,表示模型遗忘旧任务的程度越低,学习新任务时保持旧任务知识的能力越强。
(2)本文从几个新颖的角度探讨了连续学习的性能,包括平均遗忘率、超参数对平均遗忘率的影响、网络容量的使用情况、验证集准确率 _A_1 与压缩百分比之间的关系、网络权重的重用率百分比。
(3)算法中的硬是什么意思?

  • 在传统的软注意力机制中,注意力权重是连续的,通常在0到1之间变化,表示对输入的不同部分的关注程度。而在硬注意力机制中,注意力权重被设计为接近二值化(即0或1),这种二值化是通过训练过程中的退火参数s来实现的。
  • 这种二值化有助于在学习新任务时保持对先前任务的权重更新的控制,从而减少对先前任务知识的遗忘。
  • 硬注意力通过形成硬掩码(即确定性的二值掩码),为网络权重提供了一种稳定的保护机制。这些掩码在训练过程中学习得到,并在测试时固定,类似于硬件中的“掩码”,因而称为“硬”。
  • 硬注意力机制通过简单的门控机制(如Sigmoid函数)来控制权重的更新,使得部分权重在训练新任务时保持不变,简化了权重的动态调整过程。

(4)算法中的注意力是什么意思?

  • 注意力指的是模型对不同任务的关注度。通过学习任务相关的嵌入,模型能够识别并专注于对当前任务重要的网络部分。
  • 硬注意力机制允许模型在学习新任务时,有选择性地更新权重。对于那些对先前任务重要的权重,通过硬掩码将其保护起来,减少或避免更新,从而减少遗忘。
  • 类似于人类的注意力机制,硬注意力能够让模型在处理不同任务时,动态地调整其关注的网络路径。这有助于模型在面对新任务时能够快速适应,同时保留对旧任务的记忆。
目录
相关文章
|
8天前
|
算法 JavaScript 前端开发
第一个算法项目 | JS实现并查集迷宫算法Demo学习
本文是关于使用JavaScript实现并查集迷宫算法的中国象棋demo的学习记录,包括项目运行方法、知识点梳理、代码赏析以及相关CSS样式表文件的介绍。
第一个算法项目 | JS实现并查集迷宫算法Demo学习
|
12天前
|
XML JavaScript 前端开发
学习react基础(1)_虚拟dom、diff算法、函数和class创建组件
本文介绍了React的核心概念,包括虚拟DOM、Diff算法以及如何通过函数和类创建React组件。
15 2
|
2月前
|
机器学习/深度学习 算法 计算机视觉
【博士每天一篇文献-算法】持续学习经典算法之LwF: Learning without forgetting
LwF(Learning without Forgetting)是一种机器学习方法,通过知识蒸馏损失来在训练新任务时保留旧任务的知识,无需旧任务数据,有效解决了神经网络学习新任务时可能发生的灾难性遗忘问题。
99 9
|
2月前
|
算法 Java
掌握算法学习之字符串经典用法
文章总结了字符串在算法领域的经典用法,特别是通过双指针法来实现字符串的反转操作,并提供了LeetCode上相关题目的Java代码实现,强调了掌握这些技巧对于提升算法思维的重要性。
|
2月前
|
算法 NoSQL 中间件
go语言后端开发学习(六) ——基于雪花算法生成用户ID
本文介绍了分布式ID生成中的Snowflake(雪花)算法。为解决用户ID安全性与唯一性问题,Snowflake算法生成的ID具备全局唯一性、递增性、高可用性和高性能性等特点。64位ID由符号位(固定为0)、41位时间戳、10位标识位(含数据中心与机器ID)及12位序列号组成。面对ID重复风险,可通过预分配、动态或统一分配标识位解决。Go语言实现示例展示了如何使用第三方包`sonyflake`生成ID,确保不同节点产生的ID始终唯一。
go语言后端开发学习(六) ——基于雪花算法生成用户ID
|
4天前
|
传感器 算法 C语言
基于无线传感器网络的节点分簇算法matlab仿真
该程序对传感器网络进行分簇,考虑节点能量状态、拓扑位置及孤立节点等因素。相较于LEACH算法,本程序评估网络持续时间、节点死亡趋势及能量消耗。使用MATLAB 2022a版本运行,展示了节点能量管理优化及网络生命周期延长的效果。通过簇头管理和数据融合,实现了能量高效和网络可扩展性。
|
1天前
|
算法 数据挖掘
基于粒子群优化算法的图象聚类识别matlab仿真
该程序基于粒子群优化(PSO)算法实现图像聚类识别,能识别0~9的数字图片。在MATLAB2017B环境下运行,通过特征提取、PSO优化找到最佳聚类中心,提高识别准确性。PSO模拟鸟群捕食行为,通过粒子间的协作优化搜索过程。程序包括图片读取、特征提取、聚类分析及结果展示等步骤,实现了高效的图像识别。
|
1月前
|
算法 BI Serverless
基于鱼群算法的散热片形状优化matlab仿真
本研究利用浴盆曲线模拟空隙外形,并通过鱼群算法(FSA)优化浴盆曲线参数,以获得最佳孔隙度值及对应的R值。FSA通过模拟鱼群的聚群、避障和觅食行为,实现高效全局搜索。具体步骤包括初始化鱼群、计算适应度值、更新位置及判断终止条件。最终确定散热片的最佳形状参数。仿真结果显示该方法能显著提高优化效率。相关代码使用MATLAB 2022a实现。
|
1月前
|
算法 数据可视化
基于SSA奇异谱分析算法的时间序列趋势线提取matlab仿真
奇异谱分析(SSA)是一种基于奇异值分解(SVD)和轨迹矩阵的非线性、非参数时间序列分析方法,适用于提取趋势、周期性和噪声成分。本项目使用MATLAB 2022a版本实现从强干扰序列中提取趋势线,并通过可视化展示了原时间序列与提取的趋势分量。代码实现了滑动窗口下的奇异值分解和分组重构,适用于非线性和非平稳时间序列分析。此方法在气候变化、金融市场和生物医学信号处理等领域有广泛应用。
|
1月前
|
资源调度 算法
基于迭代扩展卡尔曼滤波算法的倒立摆控制系统matlab仿真
本课题研究基于迭代扩展卡尔曼滤波算法的倒立摆控制系统,并对比UKF、EKF、迭代UKF和迭代EKF的控制效果。倒立摆作为典型的非线性系统,适用于评估不同滤波方法的性能。UKF采用无迹变换逼近非线性函数,避免了EKF中的截断误差;EKF则通过泰勒级数展开近似非线性函数;迭代EKF和迭代UKF通过多次迭代提高状态估计精度。系统使用MATLAB 2022a进行仿真和分析,结果显示UKF和迭代UKF在非线性强的系统中表现更佳,但计算复杂度较高;EKF和迭代EKF则更适合维数较高或计算受限的场景。
下一篇
无影云桌面