【博士每天一篇文献-算法】Learning without forgetting

简介: 本文提出了一种名为"无忘记学习"(Learning without Forgetting, LWF)的算法,它允许在不牺牲原有任务性能的情况下,通过仅使用新任务的数据来训练卷积神经网络以学习新的视觉能力。

阅读时间:2023-10-29

1 介绍

年份:2016
作者:李志忠; 德里克·霍伊姆,伊利诺伊大学
期刊:IEEE transactions on pattern analysis and machine intelligence
引用量:3353
提出一种名为"无忘记学习"的算法,用于在卷积神经网络(Convolutional Neural Network, CNN)中添加新的视觉能力,同时保留原有的能力,而无需使用旧的训练数据。这种算法只使用新任务的数据来训练网络,同时保持原有的能力。相比于常用的特征提取和微调适应技术,以及可能无法获取的原始任务数据的多任务学习,这种方法表现出色。该方法在计算效率上高效,不需要保留或重新应用训练数据,并且在部署上非常简单。通过更多实验证明了作者之前工作的可行性,并与其他方法进行了比较。该论文与诸如蒸馏网络、微调、多任务学习、领域适应、迁移学习、终身学习和不断学习等方法相关。该论文的方法提供了一种更直接的方式来保留对于原有任务重要的表示,相对于微调,它在大多数实验中提高了原有任务和新任务的性能。
类似算法,iead来源: Active long term memory networks
本文主要区别在于用于训练的权重衰减正则化和在完全微调之前使用的预热步骤。
本文使用大型数据集来训练我们的初始网络(例如ImageNet),然后从较小的数据集(例如PASCAL VOC)扩展到新任务。而A-LTM则使用小型数据集进行旧任务和大型数据集进行新任务。相比之下,LWF的实验表明,在没有原始任务数据的情况下,可以保持旧任务的良好性能,同时在新任务上表现得与或者有时比微调更好。

2 创新点

无遗忘学习方法LWF可以看作是反馈网络【 Distilling the knowledge in a neural network 】和微调【 Rich feature hierarchies for accurate object detection and semantic segmentation 】的结合。
微调从在相关数据丰富问题上训练的现有网络的参数初始化,并通过使用较低的学习率为新任务优化参数,找到新的局部最优点。反馈网络的思想是通过学习较简单的网络参数,使其在训练集或大型无标签数据集上产生与一个更复杂的网络集合相同的输出。LWF的方法不同之处在于,LWF使用相同数据来监督学习新任务并为旧任务提供无监督输出引导来解决适用于新旧任务的一组参数。

3 算法

(1)对比的方法
在这里插入图片描述

  • 特征提取: 这种方法使用预训练的深度卷积神经网络(CNN)从图像中提取特征。这些特征然后用于训练新任务的分类器。虽然这种方法不修改原始网络并且可以取得良好的结果,但提取的特征可能不专门用于新任务,并且可能需要进行微调。
  • 微调: 这种方法修改现有CNN的参数以训练新任务。输出层通过随机初始化权重进行扩展,并使用较小的学习率来最小化新任务的损失。有时,网络的某些层被冻结以防止过拟合。微调是使共享参数更具有辨别性的方法。这种方法通常优于特征提取。
  • 多任务学习: 这种方法旨在通过结合所有任务的共同知识来同时改进所有任务。网络的底层是共享的,而顶层是针对特定任务的。多任务学习需要所有任务的数据都存在。
  • 添加新节点到每个网络层: 这种方法在网络的每个层中添加新节点来学习新的辨别特征,同时保留原始网络参数。这可以扩展网络中的参数数量,如果没有足够的训练数据来学习新的参数,则可能表现不佳,因为这种方法需要从头开始训练大量参数。作者尝试扩展原始网络的全连接层,但发现这种扩展在LWF的方法上没有提供改进。

(2)相关的研究
知识蒸馏的方法,其中知识是从一个大型网络或网络组件传递给一个较小的网络,以实现高效部署。

  • A-LTM【Active long term memory networks】是在独立开发的基础上几乎相同的方法,但在实验和结论方面有很大的差异。方法的主要差异在于训练中使用的权重衰减正则化和在完全微调之前使用的预热步骤。本文作者使用大型数据集来训练初始网络(例如ImageNet),然后从较小的数据集(例如PASCAL VOC)扩展到新任务,而A-LTM使用小型数据集用于旧任务和大型数据集用于新任务。A-LTM中的实验发现了比LWF有更大的微调损失,并得出结论认为维护原始任务的数据对于维持性能是必要的。相比之下,实验证明LWF可以在不访问原始任务数据的情况下保持旧任务的良好性能,并在新任务中表现得和甚至更好。作者分析主要的区别在于旧任务-新任务配对的选择,并且观察到由于选择的差异(部分原因是预热步骤)导致的微调对旧任务性能的下降较小。作者认为从训练有素的网络开始并使用较少的训练数据添加任务的实验从实用角度更具有有效性。
  • Less Forgetting Learning【Less-forgetting learning in deep neural networks】也是一种类似的方法,通过阻止共享表示发生变化来保持旧任务的性能。该方法认为任务特定的决策边界不应改变,并保持旧任务的最后一层不变,LWF的方法则通过限制旧任务的输出变化来同时优化共享表示和最终层。LWF在新任务上表现优于Less Forgetting Learning。

(3)算法步骤
首先对于一个新任务来说,我们模型已经有了在之前任务训练的出来的共享参数θs​和旧任务参数θn​
,对于当前新到来的任务我们有了Xn​和Yn​。
初始化:
首先是根据新任务数据拟合出来一个旧任务标签
$$Yo​=CNN(Xn​,θs​,θo​)$$
随机初始化:θn​
训练:
旧任务预测出的结果:
$$Y_o'=CNN(X_n,\theta_s',\theta_o') $$
新任务预测结果:

$$ Yn′​=CNN(Xn​,θs′​,θn′​)$$
最后定义损失函数即可:
$$Loss=λ0​Lold​(Yo​,Yo′​)+Lnew​(Yn​,Yn′​)+R(θs​,θo​,θn​) $$
其中对于旧任务预测 $L_{old} $​,作者使用了蒸馏的交叉熵损失函数的方式,即
$$ L_{old}(Y_o,Y_o') = -\sum_{i=1}^{l}y_o'^{(i)}logy_o'^{(i)} $$
$$y_o'^{(i)} = \frac{({y_o'^{(i)})}^{1/T}}{\sum_j({y_o'^{(i)})}^{1/T}} $$
对于新任务 $ L_{new} $​,采用交叉熵损失函数
$$ L_{new} = -y_n log{\hat{y_n}} $$​
使用带有正则化的 SGD 训练网络
1) 首先固定θs​和θo​不变,然后使用新任务数据集训练θn​ 直至收敛;
2) 然后再联合训练所有参数θs​、 θo​和 θn​ 直至网络收敛。

4 实验分析

(1)实验设计
原始旧任务数据集:ImageNet [4]的ILSVRC 2012子集和Places365-standard [30]数据集
新任务大规模数据集:PASCAL VOC 2012图像分类[31](VOC)、Caltech-UCSD Birds-200-2011细粒度分类[32](CUB)和MIT室内场景分类[33](Scenes)、MNIST [34]
ImageNet [4]的ILSVRC 2012子集和Places365-standard [30]数据集、
权重:使用开放的ImageNet和Places365-standard上预训练的原始网络模型
(2)分析角度

  • 添加单个新任务到网络中的效果研究
  • 逐一添加多个任务的效果研究
  • 考察数据集大小对结果的影响
  • 考察网络设计对结果的影响
  • 进行消融研究:探究不同的保留响应损失函数、扩展网络结构以及使用较低学习率进行微调等方法对保持原有任务性能的有效性

5 代码

https://github.com/ngailapdi/LWF/blob/master/model.py

6 思考

LwF可以看作是蒸馏网络和微调的结合。LwF的整个训练过程和联合训练(Joint Training)有点类似,但不同的是LwF 不需要旧任务的数据和标签,而是用KD蒸馏损失去平衡旧任务的性能,完成了不需要访问任何旧任务数据的增量训练。

目录
相关文章
|
2月前
|
机器学习/深度学习 人工智能 资源调度
【博士每天一篇文献-算法】连续学习算法之HAT: Overcoming catastrophic forgetting with hard attention to the task
本文介绍了一种名为Hard Attention to the Task (HAT)的连续学习算法,通过学习几乎二值的注意力向量来克服灾难性遗忘问题,同时不影响当前任务的学习,并通过实验验证了其在减少遗忘方面的有效性。
58 12
|
2月前
|
机器学习/深度学习 算法 计算机视觉
【博士每天一篇文献-算法】持续学习经典算法之LwF: Learning without forgetting
LwF(Learning without Forgetting)是一种机器学习方法,通过知识蒸馏损失来在训练新任务时保留旧任务的知识,无需旧任务数据,有效解决了神经网络学习新任务时可能发生的灾难性遗忘问题。
138 9
|
2月前
|
存储 机器学习/深度学习 算法
【博士每天一篇文献-算法】连续学习算法之RWalk:Riemannian Walk for Incremental Learning Understanding
RWalk算法是一种增量学习框架,通过结合EWC++和修改版的Path Integral算法,并采用不同的采样策略存储先前任务的代表性子集,以量化和平衡遗忘和固执,实现在学习新任务的同时保留旧任务的知识。
79 3
|
13天前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于MSER和HOG特征提取的SVM交通标志检测和识别算法matlab仿真
### 算法简介 1. **算法运行效果图预览**:展示算法效果,完整程序运行后无水印。 2. **算法运行软件版本**:Matlab 2017b。 3. **部分核心程序**:完整版代码包含中文注释及操作步骤视频。 4. **算法理论概述**: - **MSER**:用于检测显著区域,提取图像中稳定区域,适用于光照变化下的交通标志检测。 - **HOG特征提取**:通过计算图像小区域的梯度直方图捕捉局部纹理信息,用于物体检测。 - **SVM**:寻找最大化间隔的超平面以分类样本。 整个算法流程图见下图。
|
8天前
|
算法
基于粒子群算法的分布式电源配电网重构优化matlab仿真
本研究利用粒子群算法(PSO)优化分布式电源配电网重构,通过Matlab仿真验证优化效果,对比重构前后的节点电压、网损、负荷均衡度、电压偏离及线路传输功率,并记录开关状态变化。PSO算法通过迭代更新粒子位置寻找最优解,旨在最小化网络损耗并提升供电可靠性。仿真结果显示优化后各项指标均有显著改善。
|
3天前
|
机器学习/深度学习 算法 数据挖掘
基于GWO灰狼优化的GroupCNN分组卷积网络时间序列预测算法matlab仿真
本项目展示了基于分组卷积神经网络(GroupCNN)和灰狼优化(GWO)的时间序列回归预测算法。算法运行效果良好,无水印展示。使用Matlab2022a开发,提供完整代码及详细中文注释。GroupCNN通过分组卷积减少计算成本,GWO则优化超参数,提高预测性能。项目包含操作步骤视频,方便用户快速上手。
|
5天前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于WOA鲸鱼优化的GroupCNN分组卷积网络时间序列预测算法matlab仿真
本项目展示了一种基于WOA优化的GroupCNN分组卷积网络时间序列预测算法。使用Matlab2022a开发,提供无水印运行效果预览及核心代码(含中文注释)。算法通过WOA优化网络结构与超参数,结合分组卷积技术,有效提升预测精度与效率。分组卷积减少了计算成本,而WOA则模拟鲸鱼捕食行为进行优化,适用于多种连续优化问题。
|
6天前
|
机器学习/深度学习 算法 5G
基于BP神经网络的CoSaMP信道估计算法matlab性能仿真,对比LS,OMP,MOMP,CoSaMP
本文介绍了基于Matlab 2022a的几种信道估计算法仿真,包括LS、OMP、NOMP、CoSaMP及改进的BP神经网络CoSaMP算法。各算法针对毫米波MIMO信道进行了性能评估,通过对比不同信噪比下的均方误差(MSE),展示了各自的优势与局限性。其中,BP神经网络改进的CoSaMP算法在低信噪比条件下表现尤为突出,能够有效提高信道估计精度。
20 2
|
15天前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于GA遗传优化的GroupCNN分组卷积网络时间序列预测算法matlab仿真
该算法结合了遗传算法(GA)与分组卷积神经网络(GroupCNN),利用GA优化GroupCNN的网络结构和超参数,提升时间序列预测精度与效率。遗传算法通过模拟自然选择过程中的选择、交叉和变异操作寻找最优解;分组卷积则有效减少了计算成本和参数数量。本项目使用MATLAB2022A实现,并提供完整代码及视频教程。注意:展示图含水印,完整程序运行无水印。
|
14天前
|
算法 决策智能
基于禁忌搜索算法的VRP问题求解matlab仿真,带GUI界面,可设置参数
该程序基于禁忌搜索算法求解车辆路径问题(VRP),使用MATLAB2022a版本实现,并带有GUI界面。用户可通过界面设置参数并查看结果。禁忌搜索算法通过迭代改进当前解,并利用记忆机制避免陷入局部最优。程序包含初始化、定义邻域结构、设置禁忌列表等步骤,最终输出最优路径和相关数据图表。