【博士每天一篇文献-算法】改进的PNN架构Progressive learning A deep learning framework for continual learning

简介: 本文提出了一种名为“Progressive learning”的深度学习框架,通过结合课程选择、渐进式模型容量增长和剪枝机制来解决持续学习问题,有效避免了灾难性遗忘并提高了学习效率。

阅读时间:2023-12-24

1 介绍

年份:2020
作者:Haytham Fayek,Lawrence Cavedon,Hong Ren Wu,皇家墨尔本理工大学计算机技术学院
期刊: Neural Networks
引用量:43
Fayek H M, Cavedon L, Wu H R. Progressive learning: A deep learning framework for continual learning[J]. Neural Networks, 2020, 128: 345-357.
提出了一种名为“Progressive learning”的深度学习框架,解决持续学习问题。Progressive learning 框架包含三个主要步骤,分别是Curriculum(课程)、Progression(渐进)和Pruning(剪枝)。课程是指主动从一组候选任务中选择一个任务进行学习。渐进是指通过添加新参数来增加模型的容量,这些新参数利用之前任务中学习到的参数,同时学习当前新任务的数据,而不会受到灾难性遗忘的影响。剪枝是指用来抵消随着学习更多任务而增加的参数数量,同时减轻不相关的先前知识可能对当前任务造成的负面前向迁移。
image.png
image.png
image.png

2 创新点

  1. 全新的框架: 提出了Progressive learning框架,这是一个针对持续学习设计的深度学习框架,它综合了课程(Curriculum)、渐进(Progression)和剪枝(Pruning)三个步骤。
  2. 主动任务选择: 通过课程过程主动选择下一个要学习的任务,这一策略考虑了当前的知识状态,并可能导致任务难度的自然发展。
  3. 模型容量增长: 渐进过程通过添加新参数(称为渐进块)来增加模型的容量,这些新参数利用之前任务中学习到的参数,同时避免了灾难性遗忘。
  4. 参数剪枝机制: 剪枝过程用于抵消随着学习新任务而增加的参数数量,并减少不相关的先前知识可能对当前任务造成的负面前向迁移。
  5. 避免灾难性遗忘: 通过仅训练新添加的参数,而保持之前块中的参数不变,从而避免了在学习新任务时对旧任务知识的灾难性遗忘。
  6. 特征重用: 使用连接操作而不是求和操作来鼓励特征的重用,这简化了模型的优化并改善了梯度流动。
  7. 贪婪逐层剪枝: 基于深度网络中各层特征特异性不同的直觉,提出了一种贪婪逐层剪枝方法,该方法按层剪枝权重,以减少参数数量而不影响性能。
  8. 跨领域评估: 在图像识别和语音识别两个领域对Progressive learning进行了评估,证明了其在相关任务中相比于独立学习、迁移学习、多任务学习和相关持续学习基线的优势。
  9. 更高效的学习速度: 展示了Progressive learning在相关任务中可以更快地收敛到更好的泛化性能,同时使用更少的专用参数。

3 相关研究

3.1 相关概念

(1)负面前向迁移(Negative Forward Transfer)
是指在一个学习系统尝试学习新任务时,之前学到的知识或经验对新任务的性能产生负面影响的现象。这通常发生在新任务与之前的任务不相关或者只有很少相关性的情况下。
正向迁移(Positive Transfer)是指从先前的任务中学习到的知识中获益,这些知识可以帮助系统更快地学习新任务然而,如果先前学到的知识与新任务的要求不一致或相冲突,就可能导致负面前向迁移,从而干扰新任务的学习过程,使得性能下降。例如,假设一个机器学习模型首先在识别猫的图片上进行了训练,然后转向识别狗的图片的任务。如果模型在识别猫方面学得太好,它可能会错误地将一些狗的图片识别为猫,因为它之前学到的特征对于新任务不够通用,这就是负面前向迁移的一个例子。

3.2 相关论文

(1)参数正则化范式
旨在为先前任务找到参数重要性的度量,这些度量可以用来在新任务训练期间适应性地调整或惩罚它们的扰动,从而减轻灾难性遗忘,例如弹性权重整合(Kirkpatrick et al., 2017)和突触智能(Zenke, Poole, & Ganguli, 2017)。
(2)功能正则化范式
通过在损失函数中引入正则化项来对抗灾难性遗忘,该项在新任务训练期间惩罚模型在一些先前任务上训练的输出与当前任务训练时的偏差,例如学习不遗忘(Li & Hoiem 2016)和通过蒸馏适应(Hou, Pan, Loy, Wang, & Lin, 2018)。
Li Z, Hoiem D. Learning without forgetting[J]. IEEE transactions on pattern analysis and machine intelligence, 2017, 40(12): 2935-2947.
Hou S, Pan X, Loy C C, et al. Lifelong learning via progressive distillation and retrospection[C]//Proceedings of the European Conference on Computer Vision (ECCV). 2018: 437-452.
(3)架构范式通过为每个新任务添加新的自适应参数来规避灾难性遗忘,例如块模块化神经网络(Terekhov et al., 2015)、渐进神经网络(Rusu et al., 2016)和残差适配器(Rebuffi, Bilen, & Vedaldi, 2017)。
Terekhov A V, Montone G, O’Regan J K. Knowledge transfer in deep block-modular neural networks[C]//Biomimetic and Biohybrid Systems: 4th International Conference, Living Machines 2015, Barcelona, Spain, July 28-31, 2015, Proceedings 4. Springer International Publishing, 2015: 268-279.
Rusu A A, Rabinowitz N C, Desjardins G, et al. Progressive neural networks[J]. arXiv preprint arXiv:1606.04671, 2016.
Rebuffi S A, Bilen H, Vedaldi A. Learning multiple visual domains with residual adapters[J]. Advances in neural information processing systems, 2017, 30.
(3)经验重放范式存储先前任务中看到的数据,要么是直接的,例如梯度情景记忆(Lopez-Paz & Ranzato, 2017),要么通过生成模型进行压缩(Shin, Lee, Kim, & Kim, 2017),并在训练中利用多任务目标,通过在训练过程中重放先前任务中看到的数据,结合新引入的任务(Isele & Cosgun, 2018)。
Lopez-Paz D, Ranzato M A. Gradient episodic memory for continual learning[J]. Advances in neural information processing systems, 2017, 30.
Shin H, Lee J K, Kim J, et al. Continual learning with deep generative replay[J]. Advances in neural information processing systems, 2017, 30.
Isele D, Cosgun A. Selective experience replay for lifelong learning[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2018, 32(1).
(4)元学习范式,即学习如何学习
Al-Shedivat M, Bansal T, Burda Y, et al. Continuous adaptation via meta-learning in nonstationary and competitive environments[J]. arXiv preprint arXiv:1710.03641, 2017.

4 算法

4.1 算法步骤

  • 初始化:选择一组候选任务和相应的数据集。
  • 对于每个任务:
    • 使用课程策略选择下一个任务。
    • 使用渐进过程训练新的任务模型(渐进块)。
    • 如果是学习后续任务,使用剪枝过程来优化模型。
    • 更新已学习任务和处理过的数据集的集合。

4.2 算法原理

  1. 课程(Curriculum):从一组候选任务中确定接下来要学习的任务,考虑当前的知识状态。对每个候选任务训练和评估一个简单的分类或回归模型,使用已有的特征。选择性能最好的任务作为下一个学习任务,即利用已学习特征最能解决的任务。
  2. 渐进(Progression):通过添加新参数来增加模型的容量,以便学习新任务,同时利用之前任务中学到的参数。为每个新任务实例化并训练一个新的多层神经网络(称为渐进块),它除了层内连接外,还从现有渐进块的相应前层接收输入。即使用连接操作将前一层的输出与之前所有块的前一层输出连接起来,形成输入。仅训练新添加的渐进块中的参数,而保持之前块中的参数不变,避免灾难性遗忘。

image.png

  1. 剪枝(Pruning):抵消由于连续学习新任务导致的参数数量增长,并减轻负面前向迁移。在每个渐进块训练收敛后,迭代地剪枝每层中权重大小最小的一部分,并继续训练以补偿被剪枝的权重。使用贪婪策略剪枝,按照从小到大的顺序逐渐增加剪枝量,直到性能下降为止,从而找到最优的剪枝量。基于网络中不同层学习到的特征特异性不同,初期层更倾向于被剪枝,因为它们学习到的特征更通用。

image.png

5 实验分析

(1)渐进学习方法的消融实验,对比11个图像识别任务上的准确性
image.png
可以观察到渐进学习在多数或所有任务上的准确率普遍高于独立学习。这表明通过渐进学习,模型能够更有效地利用先前任务学到的知识来提高新任务的学习效率和性能。独立学习在每个任务上的表现可能相对一致,因为它不依赖于之前任务学到的知识,每次任务都是从零开始学习。
(2)学习效率
image.png
根据课程策略的结果排序的,这可能意味着任务的难度或复杂性是逐渐增加的。
渐进学习曲线显示出比独立学习更快的下降趋势,这表明渐进学习方法能够更快地收敛到较低的验证误差,即学习速度更快。渐进学习之所以能够更快地学习,是因为它能够利用之前任务中学到的知识,避免了每次都从头开始学习。
(3)剪枝过程前后的对比
image.png
图A在应用剪枝过程之前,随着任务数量的增加,模型的总参数数量显著增加。剪枝过程有效地减少了模型的总参数数量,这有助于控制模型的复杂度,防止过拟合。
图B中固定参数的数量可能随着任务的增加而增加,但剪枝过程有助于减少这些参数的数量,这表明剪枝不仅影响当前正在训练的块,也可能影响之前块中的参数。
图C中自适应参数的数量代表了当前任务正在训练的新参数。剪枝过程可能对这些参数有显著影响,通过去除不重要的参数来优化模型结构。
(4)剪枝(Pruning)步骤对不同网络层的影响
image.png
与网络的深层相比,网络的初始层(接近输入层的层)更倾向于被剪枝。这表明初始层学到的特征在后续任务中可能更容易被替代或被认为是冗余的。由于在渐进学习中,先前块学到的特征可以在新的块中被重用,因此,后续块可以依赖于之前块中学习到的更通用的特征,从而减少了对初始层特征的需求。
(5)四个语音识别任务上的性能对比
image.png
四个语音识别任务可能包括自动语音识别(ASR)、语音情感识别(SER)、说话人识别(SR)和性别识别(GR)。
可以观察到,渐进学习方法在某些任务上可能表现更好,这表明它能够利用之前任务学到的知识来提高新任务的学习效率和性能。渐进学习方法可能在新任务上收敛得更快,因为它不需要每次都从零开始学习,而是可以利用之前任务中学到的特征。
(6)四个语音识别任务上的学习效率
image.png
渐进学习在多数或所有任务上显示出比独立学习更快的下降趋势,这表明它能够更快地学习并减少训练过程中的错误率。

6 思考

(1)本文在PNN的结构上结合了课程学习和剪枝的过程,创新型一般。

相关实践学习
达摩院智能语音交互 - 声纹识别技术
声纹识别是基于每个发音人的发音器官构造不同,识别当前发音人的身份。按照任务具体分为两种: 声纹辨认:从说话人集合中判别出测试语音所属的说话人,为多选一的问题 声纹确认:判断测试语音是否由目标说话人所说,是二选一的问题(是或者不是) 按照应用具体分为两种: 文本相关:要求使用者重复指定的话语,通常包含与训练信息相同的文本(精度较高,适合当前应用模式) 文本无关:对使用者发音内容和语言没有要求,受信道环境影响比较大,精度不高 本课程主要介绍声纹识别的原型技术、系统架构及应用案例等。 讲师介绍: 郑斯奇,达摩院算法专家,毕业于美国哈佛大学,研究方向包括声纹识别、性别、年龄、语种识别等。致力于推动端侧声纹与个性化技术的研究和大规模应用。
目录
相关文章
|
1月前
|
机器学习/深度学习 人工智能 资源调度
【博士每天一篇文献-算法】连续学习算法之HAT: Overcoming catastrophic forgetting with hard attention to the task
本文介绍了一种名为Hard Attention to the Task (HAT)的连续学习算法,通过学习几乎二值的注意力向量来克服灾难性遗忘问题,同时不影响当前任务的学习,并通过实验验证了其在减少遗忘方面的有效性。
38 12
|
1月前
|
机器学习/深度学习 算法 计算机视觉
【博士每天一篇文献-算法】持续学习经典算法之LwF: Learning without forgetting
LwF(Learning without Forgetting)是一种机器学习方法,通过知识蒸馏损失来在训练新任务时保留旧任务的知识,无需旧任务数据,有效解决了神经网络学习新任务时可能发生的灾难性遗忘问题。
65 9
|
1月前
|
机器学习/深度学习 算法 机器人
【博士每天一篇文献-算法】改进的PNN架构Lifelong learning with dynamically expandable networks
本文介绍了一种名为Dynamically Expandable Network(DEN)的深度神经网络架构,它能够在学习新任务的同时保持对旧任务的记忆,并通过动态扩展网络容量和选择性重训练机制,有效防止语义漂移,实现终身学习。
41 9
|
1月前
|
存储 机器学习/深度学习 算法
【博士每天一篇文献-算法】连续学习算法之HNet:Continual learning with hypernetworks
本文提出了一种基于任务条件超网络(Hypernetworks)的持续学习模型,通过超网络生成目标网络权重并结合正则化技术减少灾难性遗忘,实现有效的任务顺序学习与长期记忆保持。
23 4
|
1月前
|
存储 机器学习/深度学习 算法
【博士每天一篇文献-算法】连续学习算法之RWalk:Riemannian Walk for Incremental Learning Understanding
RWalk算法是一种增量学习框架,通过结合EWC++和修改版的Path Integral算法,并采用不同的采样策略存储先前任务的代表性子集,以量化和平衡遗忘和固执,实现在学习新任务的同时保留旧任务的知识。
65 3
|
6天前
|
算法 BI Serverless
基于鱼群算法的散热片形状优化matlab仿真
本研究利用浴盆曲线模拟空隙外形,并通过鱼群算法(FSA)优化浴盆曲线参数,以获得最佳孔隙度值及对应的R值。FSA通过模拟鱼群的聚群、避障和觅食行为,实现高效全局搜索。具体步骤包括初始化鱼群、计算适应度值、更新位置及判断终止条件。最终确定散热片的最佳形状参数。仿真结果显示该方法能显著提高优化效率。相关代码使用MATLAB 2022a实现。
|
6天前
|
算法 数据可视化
基于SSA奇异谱分析算法的时间序列趋势线提取matlab仿真
奇异谱分析(SSA)是一种基于奇异值分解(SVD)和轨迹矩阵的非线性、非参数时间序列分析方法,适用于提取趋势、周期性和噪声成分。本项目使用MATLAB 2022a版本实现从强干扰序列中提取趋势线,并通过可视化展示了原时间序列与提取的趋势分量。代码实现了滑动窗口下的奇异值分解和分组重构,适用于非线性和非平稳时间序列分析。此方法在气候变化、金融市场和生物医学信号处理等领域有广泛应用。
|
29天前
|
算法
基于模糊控制算法的倒立摆控制系统matlab仿真
本项目构建了一个基于模糊控制算法的倒立摆控制系统,利用MATLAB 2022a实现了从不稳定到稳定状态的转变,并输出了相应的动画和收敛过程。模糊控制器通过对小车位置与摆的角度误差及其变化量进行模糊化处理,依据预设的模糊规则库进行模糊推理并最终去模糊化为精确的控制量,成功地使倒立摆维持在直立位置。该方法无需精确数学模型,适用于处理系统的非线性和不确定性。
基于模糊控制算法的倒立摆控制系统matlab仿真
|
7天前
|
资源调度 算法
基于迭代扩展卡尔曼滤波算法的倒立摆控制系统matlab仿真
本课题研究基于迭代扩展卡尔曼滤波算法的倒立摆控制系统,并对比UKF、EKF、迭代UKF和迭代EKF的控制效果。倒立摆作为典型的非线性系统,适用于评估不同滤波方法的性能。UKF采用无迹变换逼近非线性函数,避免了EKF中的截断误差;EKF则通过泰勒级数展开近似非线性函数;迭代EKF和迭代UKF通过多次迭代提高状态估计精度。系统使用MATLAB 2022a进行仿真和分析,结果显示UKF和迭代UKF在非线性强的系统中表现更佳,但计算复杂度较高;EKF和迭代EKF则更适合维数较高或计算受限的场景。
|
9天前
|
算法
基于SIR模型的疫情发展趋势预测算法matlab仿真
该程序基于SIR模型预测疫情发展趋势,通过MATLAB 2022a版实现病例增长拟合分析,比较疫情防控力度。使用SIR微分方程模型拟合疫情发展过程,优化参数并求解微分方程组以预测易感者(S)、感染者(I)和移除者(R)的数量变化。![]该模型将总人群分为S、I、R三部分,通过解析或数值求解微分方程组预测疫情趋势。