【博士每天一篇文献-算法】连续学习算法之HAT: Overcoming catastrophic forgetting with hard attention to the task

简介: 本文介绍了一种名为Hard Attention to the Task (HAT)的连续学习算法,通过学习几乎二值的注意力向量来克服灾难性遗忘问题,同时不影响当前任务的学习,并通过实验验证了其在减少遗忘方面的有效性。

阅读时间:2023-12-28

1 介绍

年份:2018
作者:Joan Serrà,Sony AI;Dídac Surís,PhD student, Columbia University;Marius Miron,Earth Species Project
会议:International conference on machine learning. PMLR
引用量:1064
代码:https://github.com/joansj/hat
Serra J, Suris D, Miron M, et al. Overcoming catastrophic forgetting with hard attention to the task[C]//International conference on machine learning. PMLR, 2018: 4548-4557.
本文提出了一种基于任务的硬注意力机制(Hard Attention to the Task, HAT),通过学习几乎二值的注意力向量来保持先前任务的信息,同时不影响当前任务的学习。通过门控任务嵌入学习硬注意力掩码,这些掩码定义了网络权重的更新约束。利用先前任务的注意力向量来调整梯度,保护先前任务中重要的权重。通过调整超参数,控制学习知识的稳定性和紧凑性,使得方法适用于不同的应用场景。实验结果显示,HAT在不同的实验设置中都能显著降低遗忘率。HAT机制还提供了监控网络行为的能力,例如评估网络容量使用情况和权重重用情况。此外,还可以利用硬注意力掩码进行网络剪枝,实现模型压缩。
image.png
image.png
image.png

2 创新点

  1. 硬注意力机制:提出了一种硬注意力机制,通过学习几乎二值的注意力向量,同时对每个任务进行学习,以保持先前任务的信息,而不会影响当前任务的学习。
  2. 任务嵌入的门控:通过门控任务嵌入,使用反向传播和批量随机梯度下降(SGD),动态创建和销毁跨层的路径,这些路径在学习新任务时可以被保留。
  3. 累积注意力向量:通过累积先前任务的注意力向量来调整梯度,从而保持对先前任务重要的权重的更新,防止对这些权重进行大的更新。
  4. 压缩性和稳定性控制:通过引入超参数,控制学习知识的稳定性和紧凑性,使得该方法也适用于在线学习或网络压缩应用。
  5. 无需记忆模块:与需要使用记忆模块的方法不同,HAT不需要记忆模块,而是通过学习单元级别的掩码来约束网络权重的更新。
  6. 简化的超参数:该方法只有两个超参数,直观地控制了学习知识的稳定性和紧凑性,而且调整这些超参数对于获得良好性能并不关键。
  7. 监控能力:硬注意力机制提供了监控模型行为的性,例如通过计算条件掩码来评估哪些权重获得高注意力值,以及跨任务的权重重用情况。
  8. 网络剪枝:利用硬注意力掩码来评估网络权重的重要性,并剪枝最不重要的权重,从而实现网络的压缩。

3 相关研究

  1. 复习策略(Rehearsal):
    • 论文 (1995). Catastrophic forgetting, rehearsal and pseudorehearsal.
  2. 记忆模块(Memory Modules):
    • 论文: (2017). iCaRL: incremental classifier and representation learning.
  3. 伪复习(Pseudorehearsal):
    • 论文:(1995). Catastrophic forgetting, rehearsal and pseudorehearsal.
  4. 生成网络替代记忆模块(Generative Networks as Memory Modules):
    • 论文:(2017). A strategy for an uncompromising incremental learner.
  5. 减少表示重叠(Reducing Representational Overlap):
    • 论文:(1991). Using semi-distributed representations to overcome catastrophic forgetting in connectionist networks.
  6. 结构正则化(Structural Regularization):
    • 论文:(2017). Improved multitask learning through synaptic intelligence
  7. 增量时刻匹配(Incremental Moment Matching, IMM):
    • 论文:(2017)Overcoming catastrophic forgetting by incremental moment matching.
  8. 弹性权重巩固(Elastic Weight Consolidation, EWC):
    • 论文: (2017)Overcoming catastrophic forgetting in neural networks
  9. 路径网络(PathNet):
    • 论文:PathNet: evolution channels gradient descent in super neural networks.
  10. 渐进式神经网络(Progressive Neural Networks, PNNs):
  • 论文:(2016)Progressive neural networks.
  1. 突触智能 (Synaptic Intelligence, SI):
  • 论文:(2017). Improved multitask learning through synaptic intelligence.
  1. 打包网络 (PackNet):
  • 论文:(2017). PackNet: adding multiple tasks to a single network by iterative pruning
  1. 动态可扩展网络 (Dynamically Expandable Networks, DEN):
  • 论文: (2018). Lifelong learning with dynamically expandable networks.

4 算法

4.1 关键点

(1)前向传播
数据从输入层逐层传递,每层的输出都会与任务相关的注意力向量相乘。注意力向量 $ a_{l}^t $​是通过 $ \sigma (e_{l}^t \cdot s) $计算的,其中 σ 是Sigmoid函数, $ e_l^t $​是任务嵌入,通常使用高斯分布_N_(0,1) 进行初始化,s 是正的缩放参数。在学习新任务时,利用所有先前任务的注意力向量来计算累积注意力向量,以保持对先前任务重要的权重。
使用Sigmoid函数和正的缩放参数的目的是来构建一个伪步函数,允许梯度流动。训练过程中逐渐增加缩放参数,使得注意力向量在测试时接近二值。
注意力向量 $ a_l^t $​ 与每层的输出 $ h_l $​进行元素级乘法操作,生成新的层输出 $ h_l=a_l^t\odot h_l $
(2)反向传播
在训练过程中,梯度会根据累积的注意力向量进行调整,以保护对先前任务重要的权重。梯度调整是通过最小化当前层和前一层的累积注意力向量的元素来实现。
为了解决在训练过程中嵌入梯度变化不大的问题,引入了嵌入梯度补偿机制。通过除以退火Sigmoid函数的导数并乘以期望的补偿量,调整梯度。
并且在损失函数中添加正则化项来促进注意力向量的稀疏性,从而为未来任务保留模型容量。
(3)补充
本文提出可以利用硬注意力机制监控模型的行为,例如评估网络容量的使用情况和权重的重用情况。此外,还可以使用硬注意力掩码来评估网络权重的重要性,并进行剪枝,从而实现模型压缩。

2.2 算法步骤

image.png

  1. 初始化
    • 在训练开始之前,任务嵌入 $ e_l^t $​需要被初始化。通常使用高斯分布N(0, 1) 进行初始化。
  2. 前向传播
    • 在每次训练迭代中,输入数据通过网络的每层进行前向传播,生成每层的输出 $ h_l $​。
  3. 计算门控任务嵌入
    • 对于每层l和每个任务t,计算门控任务嵌入注意力向量 $ a_l^t $​。这通常通过一个门控函数(如Sigmoid函数)和缩放参数s来实现:
      $ a_l^t = \sigma(s \cdot e_l^t) $
    • 其中 σ是Sigmoid函数,s是一个正的缩放参数。
  4. 元素级乘法
    • 将注意力向量 $ a_l^t $​与每层的输出 $ h_l $​进行元素级乘法操作,生成新的层输出 $ h_l = a_l^t \odot h_l $。
  5. 反向传播
    • 在损失函数计算后,通过反向传播算法计算每层的梯度。对于任务嵌入 $ e_l^t $​的梯度计算如下:
      $ \frac{\partial L}{\partial e_l^t} = \frac{\partial L}{\partial h_l} \odot a_l^t \odot (1 - a_l^t) ∂elt​$
    • 其中L是损失函数, ⊙表示元素级乘法。
  6. 梯度补偿
    • 为了增强梯度并促进 $ e_l^t $​的学习,应用梯度补偿。梯度补偿的公式如下:
      $ q_ l^t = s{\text{max}} \cdot (\cosh(e_lt) + 1) \cdot \frac{\partial L}{\partial e_l^t} $
    • 这里 $ s_{\text{max}} $​是一个超参数,用于控制梯度补偿的强度。
  7. 更新任务嵌入
    • 使用梯度下降算法更新任务嵌入 $ e_l^t $​:
      $ e_l^t \leftarrow e_l^t - \eta \cdot q_l^t $
    • 其中 η是学习率。
  8. 退火参数s
    • 在训练过程中,逐渐增加缩放参数s 退火,使得 $ a_l^t $​逐渐接近二值化,从而在测试阶段更稳定:
      $ s = \frac{1}{s_{\text{max}} + \left( \frac{s_{\text{max}} - 1}{s_{\text{max}}} \right) \cdot \left( \frac{b}{B} \right)} $
    • 其中b是当前批次索引,B是总批次数。

5 实验分析

(1)退火参数 s 对于任务嵌入 $ e_l^t $​ 的梯度影响
image.png
Compensated这条曲线展示了应用了梯度补偿后的梯度分布。梯度补偿的目的是调整梯度的分布,使得它们在 $ e_l^t $​的期望活动范围内有更大的梯度值。
算法梯度更新中,通过除以退火Sigmoid函数的导数并乘以期望的补偿量,HAT算法调整梯度的分布和幅度,使得 $ e_l^t $​的梯度在训练过程中能够更有效地更新。
(2)平均遗忘比率 ρ≤t 随任务数 t的变化情况
image.png
该比率越接近0,表示模型遗忘旧任务的程度越低。HAT方法在所有任务数t$\ge$2 的情况下,遗忘比率始终低于其他基线方法。这表明HAT方法在减少灾难性遗忘方面非常有效。随着任务数的增加,HAT方法的遗忘比率增加幅度较小,显示出较好的稳定性和鲁棒性。

(3)平均遗忘比率
image.png
在学习第二个任务( ρ ≤ 2)和最后一个任务( ρ ≤ 8)之后的平均遗忘比率。这些数据是通过10次运行实验得出的,括号内是标准差。

  1. HAT方法的优势
    • HAT(硬注意力到任务)方法在学习第二个任务后的平均遗忘比率 ρ≤ 2 为 -0.02,是所有方法中最低的。这表明HAT在初始任务转换时遗忘最少。
    • 在学习最后一个任务后,HAT的 ρ ≤ 8 为 -0.06,同样是所有方法中最低的。这表明即使在多次任务转换后,HAT也能保持较低的遗忘率。
  2. 基线方法的表现
    • 其他基线方法如LFL(Less-Forgetting Learning)、LWF(Learning Without Forgetting)、SGD(随机梯度下降)等在学习第二个任务后遗忘比率较高,特别是LFL和LWF,其 ρ ≤ 2 分别为 -0.73 和 -0.14。
    • 在学习最后一个任务后,这些方法的遗忘比率进一步增加,例如LFL的 ρ ≤ 8为 -0.92,LWF的 ρ ≤ 8为 -0.80。
  3. 遗忘比率的减少
    • HAT方法在 ρ ≤ 2和 ρ ≤ 8 上的低遗忘比率表明其在学习新任务时能够有效地保留旧任务的知识。
  4. 标准差的稳定性
    • HAT方法的标准差较低,表明其在不同实验运行中表现稳定,不受随机因素的影响。
  5. 其他方法的敏感性
    • 一些方法如LFL和PathNet显示出较高的标准差,这表明它们对超参数、初始化或数据集的变化较为敏感。

(4)超参数 $ s_{max} $​和c对平均遗忘比率 ρ ≤ 8 的影响。

image.png

  • $ s_{max} $​:控制门控函数的极性或“硬度”。较高的 $ s_{max} $​值会使门控函数更接近单位步函数,从而在训练过程中更稳定地保留先前任务的信息。
  • c:控制模型的紧凑性。较高的 c 值会促使模型学习更紧凑的表示,减少对当前任务的容量使用。
  • 随着 $x s_{max} $​ 值的增加,平均遗忘比率 ρ ≤ 8 通常会降低。这表明更高的 $ s_{max} $​有助于更好地保留先前任务的知识。然而, $ s_{max} $​ 过高会限制模型对新任务的适应能力,因此需要找到一个平衡点。
  • c 值的增加会导致 ρ ≤ 8 增加,这表明更高的 c值会使模型在学习新任务时遗忘更多的旧任务知识。c 值较低时,模型会使用更多的网络容量来学习当前任务,从而减少遗忘。

(5)顺序任务学习时,网络容量的使用情况
image.png

  • 曲线显示了在顺序学习多个任务时,网络如何动态地调整其容量的使用。这表明HAT算法能够根据每个任务的需求调整网络的活跃部分。
  • 虚线垂直线表示任务切换点,可以看到在这些点附近,网络容量的使用率会发生变化。这反映了模型在适应新任务时对网络容量的重新分配。
  • 在某些任务中,网络容量的使用率会降低,这表明HAT算法在学习新任务时能够保留一部分网络容量用于未来任务。这种保留有助于减少灾难性遗忘。
  • HAT算法不仅能够处理单个任务,还能够在多个任务之间灵活地调整网络容量的使用,这使得模型在顺序学习环境中具有更好的可扩展性和灵活性。

(6)顺序学习多个任务时,网络权重的重用率百分比
image.png
表格中的每个单元格显示了从任务 ti​到任务 tj​权重重用的百分比。

  • 表格中的数值越高,表示从先前任务到当前任务的权重重用率越高。这表明模型能够将先前任务学到的知识应用到新任务中,减少了从头开始学习的需求。
  • 不同任务之间的权重重用率会有显著差异。这与任务之间的相似性有关。例如,相似的任务(如MNIST和NotMNIST)会有更高的权重重用率。
  • 通过分析权重重用率,可以评估模型在不同任务之间的适应性和灵活性。高权重重用率表明模型能够更好地利用已有知识,而低权重重用率需要模型进行更多的调整以适应新任务。

(7)验证集准确率 A1 与压缩百分比之间的关系
image.png
压缩百分比指的是网络中被移除的权重或单元的比例。每个点代表一个训练周期(Epoch)后,对应压缩百分比下的验证集准确率。三角形表示使用标准随机梯度下降(SGD)方法训练的模型的准确率,作为不进行压缩时的参考。

  • 图中显示了随着压缩百分比的增加,验证集准确率的变化。一般来说,随着压缩百分比的增加,准确率会有所下降,因为更多的权重被移除。
  • 尽管随着压缩百分比的增加,准确率会降低,但图中的点仍然显示出在较高压缩率下仍然可以获得相对较高的准确率。这表明通过适当的压缩,模型可以在减少参数数量的同时保持较好的性能。
  • 三角形表示不进行压缩时SGD方法的准确率。通过比较点和三角形,可以评估压缩对模型性能的影响。如果压缩后的准确率接近或仅略低于不压缩的准确率,这表明压缩是有效的。
  • 在实际应用中,需要在模型大小和准确率之间找到平衡。图中的点可以帮助确定最佳的压缩率,以在减少模型大小的同时最小化准确率的损失。

6 思考

(1)计算平均遗忘率?

  1. 定义遗忘率
    • 对于每个任务 $ \tau $,在顺序学习了任务t ) 之后,计算其测试集上的准确率 Aτ≤t​。
    • 同时定义一个随机分类器的准确率 $ A_{\tau R} $​,它使用任务 τ的类别信息进行随机分类。
    • 还定义一个多任务学习下的准确率 $ A_{\tau \leq t J} $​,即在同时学习t个任务的情况下,任务 τ的准确率。
  2. 计算单个任务的遗忘率
    • 使用以下公式计算每个任务 τ在学习了任务t之后的遗忘率:
      $ \rho_{\tau \leq t} = \frac{A_{\tau \leq t} - A_{\tau R}}{A_{\tau \leq t J} - A_{\tau R} - 1} $
    • 其中:
      • $ A_{\tau \leq t} $是在顺序学习了任务 t之后,任务 τ的准确率。
      • $ A_{\tau R} $​是随机分类器的准确率。
      • $ A_{\tau \leq t J} $​是同时学习t个任务的情况下,任务 τ的准确率。
  3. 计算平均遗忘率
    • 对于每个任务 τ,从 τ = 1 到 τ = t ,计算 $ \rho_{\tau \leq t} $​。
    • 然后取这些遗忘率的平均值,得到平均遗忘率:
      $ \rho_{\leq t} = \frac{1}{t} \sum_{\tau=1}^{t} \rho_{\tau \leq t} $

这个平均遗忘率 ρ≤t​反映了在顺序学习了t个任务之后,模型相对于随机分类器和多任务学习情况的平均遗忘程度。值越接近0,表示模型遗忘旧任务的程度越低,学习新任务时保持旧任务知识的能力越强。
(2)本文从几个新颖的角度探讨了连续学习的性能,包括平均遗忘率、超参数对平均遗忘率的影响、网络容量的使用情况、验证集准确率 _A_1 与压缩百分比之间的关系、网络权重的重用率百分比。
(3)算法中的硬是什么意思?

  • 在传统的软注意力机制中,注意力权重是连续的,通常在0到1之间变化,表示对输入的不同部分的关注程度。而在硬注意力机制中,注意力权重被设计为接近二值化(即0或1),这种二值化是通过训练过程中的退火参数s来实现的。
  • 这种二值化有助于在学习新任务时保持对先前任务的权重更新的控制,从而减少对先前任务知识的遗忘。
  • 硬注意力通过形成硬掩码(即确定性的二值掩码),为网络权重提供了一种稳定的保护机制。这些掩码在训练过程中学习得到,并在测试时固定,类似于硬件中的“掩码”,因而称为“硬”。
  • 硬注意力机制通过简单的门控机制(如Sigmoid函数)来控制权重的更新,使得部分权重在训练新任务时保持不变,简化了权重的动态调整过程。

(4)算法中的注意力是什么意思?

  • 注意力指的是模型对不同任务的关注度。通过学习任务相关的嵌入,模型能够识别并专注于对当前任务重要的网络部分。
  • 硬注意力机制允许模型在学习新任务时,有选择性地更新权重。对于那些对先前任务重要的权重,通过硬掩码将其保护起来,减少或避免更新,从而减少遗忘。
  • 类似于人类的注意力机制,硬注意力能够让模型在处理不同任务时,动态地调整其关注的网络路径。这有助于模型在面对新任务时能够快速适应,同时保留对旧任务的记忆。
目录
相关文章
|
18天前
|
存储 算法 安全
2024重生之回溯数据结构与算法系列学习之串(12)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丟脸好嘛?】
数据结构与算法系列学习之串的定义和基本操作、串的储存结构、基本操作的实现、朴素模式匹配算法、KMP算法等代码举例及图解说明;【含常见的报错问题及其对应的解决方法】你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
2024重生之回溯数据结构与算法系列学习之串(12)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丟脸好嘛?】
|
15天前
|
机器学习/深度学习 人工智能 自然语言处理
【EMNLP2024】基于多轮课程学习的大语言模型蒸馏算法 TAPIR
阿里云人工智能平台 PAI 与复旦大学王鹏教授团队合作,在自然语言处理顶级会议 EMNLP 2024 上发表论文《Distilling Instruction-following Abilities of Large Language Models with Task-aware Curriculum Planning》。
|
18天前
|
算法 安全 搜索推荐
2024重生之回溯数据结构与算法系列学习(8)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
数据结构王道第2.3章之IKUN和I原达人之数据结构与算法系列学习x单双链表精题详解、数据结构、C++、排序算法、java、动态规划你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
|
19天前
|
存储 算法 安全
2024重生之回溯数据结构与算法系列学习之顺序表【无论是王道考研人还真爱粉都能包会的;不然别给我家鸽鸽丢脸好嘛?】
顺序表的定义和基本操作之插入;删除;按值查找;按位查找等具体详解步骤以及举例说明
|
18天前
|
算法 安全 搜索推荐
2024重生之回溯数据结构与算法系列学习之单双链表精题详解(9)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
数据结构王道第2.3章之IKUN和I原达人之数据结构与算法系列学习x单双链表精题详解、数据结构、C++、排序算法、java、动态规划你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
|
19天前
|
存储 Web App开发 算法
2024重生之回溯数据结构与算法系列学习之单双链表【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
数据结构之单双链表按位、值查找;[前后]插入;删除指定节点;求表长、静态链表等代码及具体思路详解步骤;举例说明、注意点及常见报错问题所对应的解决方法
|
18天前
|
算法 安全 NoSQL
2024重生之回溯数据结构与算法系列学习之栈和队列精题汇总(10)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
数据结构王道第3章之IKUN和I原达人之数据结构与算法系列学习栈与队列精题详解、数据结构、C++、排序算法、java、动态规划你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
|
19天前
|
算法 安全 NoSQL
2024重生之回溯数据结构与算法系列学习之顺序表习题精讲【无论是王道考研人还真爱粉都能包会的;不然别给我家鸽鸽丢脸好嘛?】
顺序表的定义和基本操作之插入;删除;按值查找;按位查找习题精讲等具体详解步骤以及举例说明
|
19天前
|
存储 算法 安全
2024重生之回溯数据结构与算法系列学习【无论是王道考研人还真爱粉都能包会的;不然别给我家鸽鸽丢脸好嘛?】
数据结构的基本概念;算法的基本概念、特性以及时间复杂度、空间复杂度等举例说明;【含常见的报错问题及其对应的解决方法】
|
18天前
|
算法 安全 搜索推荐
2024重生之回溯数据结构与算法系列学习之王道第2.3章节之线性表精题汇总二(5)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
IKU达人之数据结构与算法系列学习×单双链表精题详解、数据结构、C++、排序算法、java 、动态规划 你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!