阅读时间:2023-10-23
1 介绍
年份:2023
作者:沈明格,陈德虎,滕仁。温州自然灾害立体智能监测预警重点实验室,温州工业大学
期刊:IEEE Access
引用量:0
作者提出了一种新颖的终身学习框架,利用元学习来学习任务之间的相似性表示,并防止遗忘先前的知识。该框架包括一个跨域三元组网络(CDTN),用于学习域不变的相似性表示,一个自注意模块,用于增强相似性特征的提取,以及一个软注意网络(SAN),根据学习到的相似性表示为任务分配不同的权重。
垃圾论文,太水了,文献标注全是错的,牛头不对马嘴。
2 相关研究
ICARL算法【 icarl: Incremental classifier and represen tation learning】,该算法使用教师网络和学生网络,以少量训练样本快速收敛所有已学习的任务。这种方法在学习新任务时只需要存储前一任务的少量样本,从而减少了存储开销。
GEM【 Gradient episodic memory for continual learning.】存储先前任务的梯度,确保新任务的梯度更新与先前任务正交。这减少了先前知识的干扰。
LwF 【 Learning without forgetting】限制只对与先前任务一致的参数进行更改。EWC 【Overcoming catastrophic forgetting in neural networks】使用先前训练的Fisher信息矩阵来衡量参数的重要性。然而,当任务很多时,这种方法可能会对网络造成过多的限制,并阻碍新的学习。一些方法,如SI算法[45],通过考虑从先前任务到新任务的参数变化来解决这个问题。
【‘‘ITAML: An incremental task-agnostic meta-learning approach,’’ in Proc. IEEE/CVF Conf. Comput. Vis. Pattern Recognit. (CVPR), Jun. 2020】将元学习方法应用于获取通用参数,这些参数不特定于旧任务或新任务,以防止灾难性遗忘。
【Experience replay for continual learning,’’ in Proc. Adv. Neural Inf. Process. Syst., vol. 32, 2019】经验重放。
3 创新点
采用元学习的方法,设计了一个跨领域三元组网络(CDTN),用于学习领域不变的相似性表示。该网络通过自注意机制,加强相似性特征的提取,并通过软注意网络(SAN)根据学习到的相似性表示为不同任务分配不同的权重。
4 模型
第一阶段中,跨领域三元组网络(CDTN)可以学习任务的相似性表示,不仅在相同领域中,而且在不同领域中。使用最大平均差异(MMD)来衡量跨领域分布差异。
在第二阶段,提议了一个软注意力网络(SAN),根据任务的相似性信息获取任务的具体注意力图。
LFEM模型中,特征图A首先通过三个1×1卷积层转换为B、C和D。然后,B和C被重新排列并相乘,通过Softmax函数获得注意力图S。最后,特征图D与S相乘,得到的特征图与A相加,得到最终的特征图E。
最后SAN使用交叉熵损失和随机梯度下降 (SGD) 来训练。
5 实验结果分析
(1)性能评估
评价指标:平均准确率AA、平均遗忘率AF
PackNet [12]和HAT [71]的容量有限,在新任务上的表现比我们的方法差。但它们通过锁定任务参数使用掩码来保留所有知识。EWC [7]和IMM [72]随着时间的推移仍然会遗忘。GEM [41]和ICARL [9]也会有一定程度的遗忘,但它们需要存储新任务的训练样本,这需要更多的空间。
(2)模型容量的影响
具有高容量的模型可以学习更多的任务。
当学习新任务时,会使用更多的权重。在训练过程中,使用率会首先缓慢下降,然后加快直到停止。这意味着网络可以缩小10%到50%,这取决于任务。当学习第四个任务时,使用的新参数较少,因为它与任务2相似。该方法利用任务相似性来改善学习。但是,在学习第8个任务时,没有类似的任务之前,前5个任务的使用量增加了约10%。与学习相似任务时相比,该方法使用的参数比PackNet少25%到80%,使用的参数比HAT少15%到70%。
表中显示了模型在多任务分类方面的表现。即使在CIFAR-100数据集中学习10个任务时,准确率也保持一致,没有忘记。当添加更多任务时,旧任务变得更好。这是因为该方法使用任务之间的相似性和来自损失函数的稀疏性来连续学习多个任务。
(3) 消融研究ablation study
仅有CDTN使平均准确率提高了约4%,平均遗忘率减少了近0.2%。这表明任务相似性信息有助于学习新任务。此外,在终身学习步骤中使用的LLEF将平均准确率提高了超过2%,证明了LLEF的非常有效。
6 思考
(1)第一阶段的元学习部分是如何实现的
具体的实现步骤是如何的?需要阅读代码进一步了解。
(2)模型的容量是怎么去评估的?如何计算得到当前任务下的模型容量是多少?