万文长字总结「类别增量学习」的前世今生、开源工具包(2)

简介: 万文长字总结「类别增量学习」的前世今生、开源工具包

训练流程

对于新的任务的训练集,LwF的损失函数包括:

新类的标签监督信息:即新类对应的logits与标签的交叉熵损失(KL散度)

旧类的知识蒸馏:新旧模型在旧类上的的logits的交叉熵损失(包含温度系数:设置温度系数大于一,来增强模型对于类别相似性的编码能力)

参数偏移正则项,对于新模型参数关于旧模型参数偏移的正则项。

具体的伪代码如下:


5.2 iCaRL: Incremental Classifier and Representation Learning[7]


核心摘要

iCaRL可以视为Class-Incremental Learning方向许多工作的基石。文章的主要贡献包括:

给Class-Incremental Learning的设定一个规范的定义:


模型需要在一个新的类别不断出现的流式数据中进行训练。

模型在任意阶段,都应该能够对目前见到的所有类别进行准确的分类。

型的计算消耗和存储消耗必须设置一个上界或者只会随着任务数量缓慢增长。


第一次阐明了我们可以在将来的训练阶段保留一部分典型的旧类数据,这极大地提高了模型能够实现的准确率上限,并提出了一种有效的典型样本挑选策略herding:贪心的选择能够使得exemplar set 的特征均值距离总的均值最近的样本。

Herding

提出了使用保留的旧类数据来进行nearest-mean-of-exemplars的分类方式,而非直接使用训练阶段的到的线性分类器。这是因为使用交叉熵损失函数在不平衡的数据集上直接进行训练,很容易出现较大的分类器的偏执。而模型提取的特征则能够很大程度上缓解这个问题。


训练流程

当新的任务到来时:

将新来的类别数据集与保留的旧类数据的exemplar set合并得到当前轮的数据集。

使用sigmoid将模型输出的logits转化为0-1之间。将目标标签转化为one-hot向量表示。

对于新类的分类,我们使用binary cross entropy来计算损失。这里的binary cross entropy的计算仅仅考虑了所有的新的类别的计算,这种方式能够使得我们在学习新的样本的时候,不会更新旧的线性分类器中的权重向量,从而减少不均衡的数据流对样本分类的影响。

对于旧类的分类,则仿照LwF的模式,计算新旧模型在旧类上的概率输出的binary cross entropy的损失来训练模型。

iCaRLiCaRL对后来的方法的影响颇深。在此之后,相当数量的类别增量学习方法都仿照这一范式。创建一个exemplar set来存储典型的旧类样本。使用知识蒸馏来提供旧类的监督信息。

5.3 BiC[8]


核心摘要

BiC基本遵循了iCaRL的训练范式,但仍然使用线性分类器作为预测阶段的分类器。BiC指出,类别增量学习中出现的灾难性的遗忘,很重要的一个因素是由于训练集样本不均衡导致的分类器偏执。文中抽象地解释了这种训练样本不均衡导致的分类器的偏执的原因。如下图所示,图中的蓝色虚线是所有的旧类特征无偏的分布(Unbiased Distribution),绿色实线为新类样本的无偏分布,图中的蓝色实现则对应了无偏的分类器。而由于在学习新的类别的时候,我们仅仅保有一部分的旧类样本。这就导致实际训练过程中我们遇到的特征分布,可能是如蓝色实线一样的狭窄尖锐的分布,这就导致我们学习得到的分类器也会相对无偏的分类器向右偏移,导致有很大一部分旧类样本被分为新类了。BiC依照这种思路,BiC设置了一种Bias Correction的阶段,我们使用线性偏移来将新类的分布进行拉平与平移,从而使得实现与虚线重叠,得到无偏的分类器,具体的:在使用类似于iCaRL的训练模式训练完成后,我们使用预先保留的新旧类平衡的训练集来训练两个参数,分别控制分类器的平移于缩放,即:

将旧类的输出乘上并加上, 其中由Bias Correction的训练阶段得到。

5.4 WA[9]

核心摘要

WA 中指出,直接在新的数据上对模型进行finetune导致模型性能下降的原因主要有两个:

没有足够的旧类样本来进行训练,导致模型不能够保持旧类内部之间的分辨能力。

旧类样本显著少于新类样本,导致模型出现了极大的分类偏执,这种bias导致模型无论遇到旧类样本还是新类样本,都会在新类的概率输出上给出一个较大的值。

问题剖析因此,WA将该过程分为两个目标:

保持旧类之间的相对大小:即Maintaining Discrimination

处理新旧类的公平问题,即实现新旧类分类偏好的对齐:Maintaining Fairness

目标对于第一个目标,基于iCaRL模式的知识蒸馏能够很好的实现旧类的之间的Discrimination的问题。对于第二个目标,WA中发现,在不平衡的训练集上训练之后,新类所对应的的线性分类器往往具有相较于旧类线性分类器更大的权值。比较新旧类分类器权值的L2范数可以发现,新类的L2范数显著大于旧类。分类器权重比较而WA所做的则是将该L2范数拉平。

其中:

最终的实验结果证明这种简单的策略性能提升非常显著。实验结果我们在这里需要指明的一点是,这种解决思路其实并没有那么完备。因为分类器权值的大小,并不总与最终输出的logits的大小正相关。这是因为如果一个分类器对应的权值很大,那么如果一个特征与其是同向的,那么logits的大小显然与该分类器权值为正相关,但是如果一个样本的特征是与该分类器的方向是反向的,则特征与分类器的内积将会是一个很小的负值,此是则权值为负相关。因此,分类器权值很大,并不总意味着模型输出的logits很大。那么为什么这种解决方案能够很好的解决新旧类的calibration的问题呢?作者在文中给了intuitive的解释:由于现在模型结构中往往使用了非负的激活函数,典型的如relu,导致模型的特征输出,分类器的权值往往都是正值,这意味着分类器权值向量与特征向量在大多数情况下的夹角都是锐角,其内积为正数,因此是正相关的。

5.5 DER[10]

基于动态特征结构的方法已经被广泛应用于解决在Task-IL中,DER是首个将动态特征结构方法应用于Class-IL的场景下,并取得优异性能的尝试。

核心摘要

DER中说明,传统的方法会陷入稳定性-可塑性困境 (stability-plasticity dilemma): 对于一个单骨架的模型,如果不施加任何限制,给它足够的可塑性,那么它在旧类样本上的表现就会产生大幅度降低;但是如果施加过多的限制则又会导致模型没有足够的可塑性来学习新的任务。而DER则实现了相较于传统方法更好的稳定性可塑性的trade-off。DER保留并冻结旧的特征提取器来保留旧的知识,同时创建一个新的可训练的特征提取使模形具有足够的可塑性来学习新的任务。DER

训练流程

具体的,当新的任务到来时:

DER固定住原有的特征提取器,并创建新的特征提取器,将两特征拼接得到总的特征提取器

提取得到的特征送入新创建的分类器,并计算与目标的交叉熵损失。

为了更好地提取特征,DER另外使用了一个辅助分类器,仅仅使用新的特征,并要求新的特征空间能够良好的实现新类之间的辨别。而对于所有的旧类样本,辅助分类器会将他们分类到同一个标签上面。

DER还设计了一种剪枝的方式,能够在尽可能保持模型性能的基础上实现大幅度的参数削减。这种剪枝策略从Task-IL的经典方法HAT[11]中借鉴而来,将HAT的以filter的权值的掩码,转变成整个channel的掩码。

最终模型训练的损失函数为:

每个损失分别为:总的交叉熵损失、辅助分类器交叉熵损失、剪枝策略对应的稀疏解损失。

5.6 COIL[12]

经典的学习系统往往被部署在封闭环境中,学习模型可以利用预收集的数据集对固定类别的数据进行建模。然而,在开放动态环境中这种假设难以满足——新的类别会随时间不断增长,模型需要在数据流中持续地学习新类。例如,在电商平台中,每天都会新增多种产品;在社交媒体上,新的热点话题层出不穷。因此,类别增量学习模型需要在学习新类的同时不遗忘旧类别的特征。COIL观察到在增量学习的过程中,新类和旧类间存在相关性,因此可以利用它来进一步地辅助模型在不同阶段的学习。因此,COIL提出利用协同运输辅助类别增量学习过程,并基于类别间的语义相关性将不同的增量学习阶段联系起来。协同运输分为两方面:向前运输(prospective transport)旨在利用最优运输获得的知识增广分类器,作为新类分类器的初始化;向后运输(retrospective transport)旨在将新类分类器转化为旧类分类器,并防止灾难性遗忘。因此模型的知识可以在增量学习过程中双向流动,从而在学习新类的同时保持对旧类的判别能力。COIL的特征层面说明如上图所示,COIL尝试基于类别间的语义关系进行分类器迁移。例如,老虎和猫很相似,因此用于判别二者的特征也高度重合,甚至可以重用大量老虎的分类器权重作为类别猫的分类器初始化;老虎和斑马不相似,因此用于判别二者的特征也无法重用。COIL考虑在统一的嵌入空间下度量类别中心的相似关系,并以此构造类别间的距离矩阵。之后,借助最优运输算法,将类别之间的距离作为运输代价,最小化所有新类和旧类集合之间的分类器重用代价,从而基于类别之间的语义关系指导分类器重用。最后,如下图所示,分别将旧类分类器复用为新类分类器,和将新类分类器复用为旧类分类器,构造两个不同方向的知识迁移,并以此设计了损失函数用于约束模型,防止灾难性遗忘。COIL方法实现分类边界可视化:COIL分类边界

相关文章
|
3月前
|
自然语言处理
预训练模型STAR问题之开放信息抽取(OpenIE)目标的问题如何解决
预训练模型STAR问题之开放信息抽取(OpenIE)目标的问题如何解决
|
机器学习/深度学习 自然语言处理 安全
【网安专题11.8】14Cosco跨语言代码搜索代码: (a) 训练阶段 相关程度的对比学习 对源代码(查询+目标代码)和动态运行信息进行编码 (b) 在线查询嵌入与搜索:不必计算相似性
【网安专题11.8】14Cosco跨语言代码搜索代码: (a) 训练阶段 相关程度的对比学习 对源代码(查询+目标代码)和动态运行信息进行编码 (b) 在线查询嵌入与搜索:不必计算相似性
260 0
|
6月前
|
自然语言处理 Python
【相关问题解答1】bert中文文本摘要代码:import时无法找到包时,几个潜在的原因和解决方法
【相关问题解答1】bert中文文本摘要代码:import时无法找到包时,几个潜在的原因和解决方法
54 0
|
6月前
|
自然语言处理 数据挖掘 Java
20源代码模型的数据增强方法:克隆检测、缺陷检测和修复、代码摘要、代码搜索、代码补全、代码翻译、代码问答、问题分类、方法名称预测和类型预测对论文进行分组【网安AIGC专题11.15】
20源代码模型的数据增强方法:克隆检测、缺陷检测和修复、代码摘要、代码搜索、代码补全、代码翻译、代码问答、问题分类、方法名称预测和类型预测对论文进行分组【网安AIGC专题11.15】
289 0
利用abbrevr包批量输出期刊缩写
有时候用endnote导入文献后显示的是期刊全称,而用到缩写时候就需要去一些网站上一个个搜索,比如CASSI, LetPub、Pubumed等网站,或者Y叔公号里直接回复,而逛Github时候突然发现abbrevr这小R包中可以很快批量实现这个需求,在此记录一下。
106 2
|
算法 PyTorch 算法框架/工具
万文长字总结「类别增量学习」的前世今生、开源工具包(3)
万文长字总结「类别增量学习」的前世今生、开源工具包
371 0
|
机器学习/深度学习 存储 算法
万文长字总结「类别增量学习」的前世今生、开源工具包(1)
万文长字总结「类别增量学习」的前世今生、开源工具包
168 0
|
机器学习/深度学习 人工智能 算法
纠错数据标注,只需一行代码:开源项目Cleanlab发布了2.0版本
纠错数据标注,只需一行代码:开源项目Cleanlab发布了2.0版本
215 0
|
机器学习/深度学习 编解码 自然语言处理
错字修改 | 布署1个中文文文本拼蟹纠错模型
错字修改 | 布署1个中文文文本拼蟹纠错模型
309 0
Argo 数据集下载地址-具体到每天数据(包含数据说明书)
将所用的Argo数据下载地址和一键下载方式分享给大家
Argo 数据集下载地址-具体到每天数据(包含数据说明书)
下一篇
无影云桌面