万文长字总结「类别增量学习」的前世今生、开源工具包(1)

简介: 万文长字总结「类别增量学习」的前世今生、开源工具包

随着统计机器学习的逐渐成熟, 现在已经是时候打破孤立学习地传统模式,转而研究终身学习, 将机器学习推向崭新的高度。


一、什么是终身学习(Life-Long Machine Learning)?


终身机器学习(或称终身学习)是一种高级的机器学习范式, 它通过不断学习,从过去的任务当中积累知识,并用这些知识帮助未来的学习。在这样的过程中,学习者的知识越来越丰富,学习效率也越来越高。这种学习能力的特质是人类智力的重要标志。然而, 当前主流的机器学习范式是孤立学习的:给定训练数据集, 算法直接通过该训练集来生成模型(从假设空间中搜索最优或近似最优的假设)。它不会试图保留所学的知识,来提高未来的学习效率。虽然这种孤立学习范式已经取得了很大的成功, 但它需要大量的训练实例,并且只适用于定义明确而且范围狭窄的极其有限任务。相比之下, 我们人类则可以通过几个例子有效地学习, 这是因为我们在过去积累了如此多的知识。这种积累的先验知识使我们能够通过很少的数据或者付出较小的努力来高效地学习新的事物。终身学习旨在使机器学习模型具备这种能力。随着统计机器学习的逐渐成熟, 现在已经是时候打破孤立学习地传统模式,转而研究终身学习, 将机器学习推向崭新的高度。智能助手、聊天机器人和物理机器人等应用也都在迫切地需要这种终身学习能力。果没有积累所学知识并使用它来逐步学习更多知识的能力, 一个系统可能永远不会是真正的智能系统[1]Life Long Learning近年来, 终身学习(LLL)在深度学习界引起了极大的关注, 它通常被称为持续学习(Continual Learning)。虽然度神经网络(DNNs)在许多机器学习任务中取得了最好的性能, 但基于联结主义的深度学习算法存在着灾难性的遗忘的问题, 这使得实现持续学习的目标变得非常困难。当使用神经网络学习序列任务时, 模型在学习新的任务后可能因为灾难性遗忘的问题而导致模型在旧任务的表现变得很差。然而,我们的人脑却有这种非凡的能力, 能够学习大量不同的任务, 而不会出现任何负面的相互干扰。 持续学习(Continual Learning)算法试图为神经网络实现同样的能力, 并解决灾难性的遗忘问题。 因此, 从本质上讲, 持续学习执行的是对新任务的增量学习(Incremental Learning)。然而与许多其他Life Long Learning技术不同, 当前持续学习算法的重点并不是如何利用在以前任务中学到的知识来帮助更好地学习新任务。而是重点在于解决灾难性遗忘的问题。

二、什么是灾难性的遗忘[2]


灾难性的遗忘指的是,模型学习了新的知识之后,几乎彻底遗忘掉之前训练的内容。这样一个问题简言之,关注的是在Sequential Learning过程中,模型在每个学习阶段,都会接触到新的不同的数据或任务,而对于旧类数据失去或者仅有有限的访问权限。在这样的一个场景下,以神经网络为代表的联结主义模型在旧的任务上的性能会大大降低。例如在传统的图像分类模型的训练过程中,我们同时在所有数据上进行训练 (尽管目前流行的优化方法都是以Batch的形式分批优化,但是每个epoch我们仍然会在所有数据上进行训练,相当于每个epoch都会对所有的数据进行温习)。但是在continual learning的情况下,学习的任务是分task依次训练的,当我们在新的task上进行训练时,旧的task中的训练数据是不可获得(或者获得受限的)。举个简单的例子:小明是一名大学生,现在需要参加期末考试,考试科目有数字信号处理、机器学习、近代史。众所周知,因为小明是一个大学生的原因,他的记忆能力非常古怪。他可以两个小时速成一门学科,但是如果他复习了新的学科,他就会忘记旧的学科。第一天,小明复习了数字信号处理,他高兴地走出考场。第二天,小明复习了机器学习,他高兴地走出考成。第三天,小明复习了近代史,但是学校发现有人泄露了题目,于是决定三门考试同时重考。于是,小明高兴地把近代史的内容“精准地”写在了数字信号处理与机器学习的卷子上面,留下阅卷老师一脸疑惑。尽管上面的故事非常荒唐,但是在目前研究的场景中经常会出现类似的问题,比如在分类任务中,我们首先使用一些预定类别的样本训练一个模型,之后再使用一些新类别的样本来finetune这样一个网络,这会使网络识别初始类别的性能大幅度下降;再比如,在增强学习任务中,单独训练后续的任务,会使agent在前序任务的性能下降严重。如下图所示,当我们使用神经网络模型训练新的任务鱼和老虎时,模型却错误的将旧任务当中的狗分类成了鱼。

Catastrophic Forgetting

三、Continual Learning 有哪些场景?


场景一:Task-IL

任务增量学习,是最简单的Continual Learning的场景。在这种场景下,无论是训练阶段还是测试阶段,模型都被告知了当前的任务ID。这种特性导致了一些task specific component的方法出现,如packNet[3]提前为每个任务确定卷积的filter的掩码图。再如HAT会动态的根据任务为卷积训练掩码图。当给定任务ID后,则选择相应的掩码进行预测。PackNet

场景二:Domain-IL

Domain-IL相较于Task-IL在测试阶段增加了新的限制,即在预测阶段并不会告知任务的ID。模型需要在不知道任务ID的情况下,将数据正确的分类。Domain-IL的场景,常常用来处理标签空间相同,但输入分布不同的问题。例如动漫中的老虎和现实中的老虎(虎年彩蛋)。domain(一):真实世界中的老虎domain(二):动画老虎

场景三:Class-IL

在Class-IL中新的类别不断地到来,模型需要正确地将输入分类到其对应地类别当中去。这是更为严格的场景,模型在接受输入后,需要正确的识别输入对应的task-ID,然后将数据粉到正确的类别当中去。

举例[4]

下图展示了一个形象的例子,模型依次在task1、2、3、4、5上进行训练。在预测阶段Task-IL会告知task-ID,模型根据task-ID将数据分为第一类或者第二类。例如当告知task-ID为1时,模型只需要在0和1之间进行辨别。Domain-IL无法获得task-ID, 但是它需要判断输入的标签是属于集合(0,2,4,6,8)还是(1,3,5,7,9).而Class-IL需要给出具体的数字标签,即从0-9之间选择一个进行输出。MNIST上的Incremental Learning的三种不同setting

此外,目前还有更为严格地data-IL, 我们在训练时就并不显示的告知task的阶段,希望模型能够适应这种类别不稳定不均衡的数据流。此处我们不展开讨论。


四、什么是类增量学习?


一个简单的例子

Class-Incremental Learning 举例模型首先在任务 1 上进行训练, 学习分类鸟类和水母。之后,需要基于当前模型分别在任务 2 中学习鹅类和北极狐,在任务 3 中学习狗类和螃蟹。在顺序化地完成训练后,模型需要在所有已经见过的类别上进行评估,一个好的类别增量模型应该能既学得新类知识又不遗忘旧类知识。

形式化定义


类别增量学习旨在从一个数据流中不断学习新类。假设存在B个不存在类别重合的训练集, 其中表示第 b 个增量学习训练数 据集, 又称作训练任务 (task)。 是来自于类别的一个训练样本, 其中是第 b 个任务的标记空间。不同任务间不存在类别重合, 即对于 有: 。在学习第 b 个任务的过程中,只能使用当前阶段的训练数据集 更新模型. 在每个训练阶段, 模型的目标不仅是学得当前 数据集 中新类的知识,同时也要保持不遗忘之前所有学过类别的知识. 因此, 我们基于模型在所 有已知类集合 上的判别能力评估其增量学习能力. 将增量学习模型对样本 的输出 记作 , 则模型要优化的期望风险描述为:

其中 表示第 b 个任务的样本分布。 评估输入之间的差异, 在分类任务中一般使用交叉熵损 失函数。由于模型需要同时在见过的所有分布上最小化期望风险,能够满足公式 1 的模型能够在学习 新类的同时不遗忘旧类的知识。进一步地, 可以将深度神经网络按照特征提取和线性分类器层进行解耦, 则模型 由特征提取模块 和线性分类器 组成,即 为了表述方便, 我们将线性分类器 进一步表示成对于每个类分类器 的组合:

五、论文方法解读


模型解耦

为了方便之后的说明,我们首先对神经网络模型进行解耦。模型 由特征提取模块 和线性分类器 组成, 即 。为了表述方便, 我们将线性分类器 进一步表示成对于每个类分类器 的组合:

5.1 LwF: Learning without Forgetting[5]

核心摘要

LwF(Learning without Forgetting) 是Incremental Learning领域早期发表的一篇文章,论文的核心要点包括:

除了LwF本身外,还提出了Fine-tunine, Feature Extraction, Joint Training三种基本的对比方法,并对不同方法进行了分析与实验对比。

提出了使用知识蒸馏(Knowledge Distillation)的方法提供旧类的“软监督”信息来缓解灾难性遗忘的问题。并且发现,这种监督信息即使没有旧类的数据仍然能够很大程度上提高旧类的准确率。

对参数偏移的正则惩罚系数、正则形式、模型拓展方式等等因素进行了基本的实验对比。(不过具论文中结果这些因素的影响并不明显)。

方法比较

Learning Without Forgetting如图中所示:

(a) 中为传统的多分类模型,它接受一张图片,然后通过线性变换、非线性激活函数、卷积、池化等运算操作符输出该图片在各个类别上的概率。

(b) 中为Fine-tuning方法,即训练新类时,我们保持旧的分类器不变,直接训练前面的特征提取器和新的分类器权重。

(c) 称为Feature Extraction,保持特征提取器不变,保持旧的分类器权重不变,只训练新的任务对应的参数。

(d) 中为Joint Training的方法,它在每个训练任务时刻都同时接受所有的训练数据进行训练。

(e) 中为LwF方法,他在Fine-tuning的基础上,为旧类通过知识蒸馏提供了一种“软”监督信息。

知识蒸馏(Knowledge Distillation)[6]

知识蒸馏(Knowledge Distilling)最初是模型压缩的一种方法,是指利用已经训练的一个较复杂的Teacher模型,指导一个较轻量的Student模型训练,从而在减小模型大小和计算资源的同时,尽量保持原Teacher模型的准确率的方法。其基本的形式为:

其中为第i类的logits输出, 为温度系数。知识蒸馏的损失函数可以看作是最小化Teacher模型和Student模型在已有数据集上数据似然的KL散度。这种监督信息相较于一般的标签分布一方面更加的平滑,另外一方面能够一定程度上反应不同类别之间的相似关系。在LwF的模型中,我们使用额外的内存开销保存旧的模型,当训练新的模型时,使用旧的模型作为旧类的Teacher模型。

训练流程

对于新的任务的训练集,LwF的损失函数包括:

新类的标签监督信息:即新类对应的logits与标签的交叉熵损失(KL散度)

旧类的知识蒸馏:新旧模型在旧类上的的logits的交叉熵损失(包含温度系数:设置温度系数大于一,来增强模型对于类别相似性的编码能力)

参数偏移正则项,对于新模型参数关于旧模型参数偏移的正则项。

具体的伪代码如下:


5.2 iCaRL: Incremental Classifier and Representation Learning[7]


核心摘要

iCaRL可以视为Class-Incremental Learning方向许多工作的基石。文章的主要贡献包括:

给Class-Incremental Learning的设定一个规范的定义:


模型需要在一个新的类别不断出现的流式数据中进行训练。

模型在任意阶段,都应该能够对目前见到的所有类别进行准确的分类。

型的计算消耗和存储消耗必须设置一个上界或者只会随着任务数量缓慢增长。


第一次阐明了我们可以在将来的训练阶段保留一部分典型的旧类数据,这极大地提高了模型能够实现的准确率上限,并提出了一种有效的典型样本挑选策略herding:贪心的选择能够使得exemplar set 的特征均值距离总的均值最近的样本。

Herding

提出了使用保留的旧类数据来进行nearest-mean-of-exemplars的分类方式,而非直接使用训练阶段的到的线性分类器。这是因为使用交叉熵损失函数在不平衡的数据集上直接进行训练,很容易出现较大的分类器的偏执。而模型提取的特征则能够很大程度上缓解这个问题。


训练流程

当新的任务到来时:

将新来的类别数据集与保留的旧类数据的exemplar set合并得到当前轮的数据集。

使用sigmoid将模型输出的logits转化为0-1之间。将目标标签转化为one-hot向量表示。

对于新类的分类,我们使用binary cross entropy来计算损失。这里的binary cross entropy的计算仅仅考虑了所有的新的类别的计算,这种方式能够使得我们在学习新的样本的时候,不会更新旧的线性分类器中的权重向量,从而减少不均衡的数据流对样本分类的影响。

对于旧类的分类,则仿照LwF的模式,计算新旧模型在旧类上的概率输出的binary cross entropy的损失来训练模型。

iCaRL iCaRL对后来的方法的影响颇深。在此之后,相当数量的类别增量学习方法都仿照这一范式。创建一个exemplar set来存储典型的旧类样本。使用知识蒸馏来提供旧类的监督信息。

5.3 BiC[8]


核心摘要

BiC基本遵循了iCaRL的训练范式,但仍然使用线性分类器作为预测阶段的分类器。BiC指出,类别增量学习中出现的灾难性的遗忘,很重要的一个因素是由于训练集样本不均衡导致的分类器偏执。文中抽象地解释了这种训练样本不均衡导致的分类器的偏执的原因。如下图所示,图中的蓝色虚线是所有的旧类特征无偏的分布(Unbiased Distribution),绿色实线为新类样本的无偏分布,图中的蓝色实现则对应了无偏的分类器。而由于在学习新的类别的时候,我们仅仅保有一部分的旧类样本。这就导致实际训练过程中我们遇到的特征分布,可能是如蓝色实线一样的狭窄尖锐的分布,这就导致我们学习得到的分类器也会相对无偏的分类器向右偏移,导致有很大一部分旧类样本被分为新类了。BiC依照这种思路,BiC设置了一种Bias Correction的阶段,我们使用线性偏移来将新类的分布进行拉平与平移,从而使得实现与虚线重叠,得到无偏的分类器,具体的:在使用类似于iCaRL的训练模式训练完成后,我们使用预先保留的新旧类平衡的训练集来训练两个参数,分别控制分类器的平移于缩放,即:

将旧类的输出乘上并加上, 其中由Bias Correction的训练阶段得到。

5.4 WA[9]

核心摘要

WA 中指出,直接在新的数据上对模型进行finetune导致模型性能下降的原因主要有两个:

没有足够的旧类样本来进行训练,导致模型不能够保持旧类内部之间的分辨能力。

旧类样本显著少于新类样本,导致模型出现了极大的分类偏执,这种bias导致模型无论遇到旧类样本还是新类样本,都会在新类的概率输出上给出一个较大的值。

问题剖析 因此,WA将该过程分为两个目标:

保持旧类之间的相对大小:即Maintaining Discrimination

处理新旧类的公平问题,即实现新旧类分类偏好的对齐:Maintaining Fairness

类。分类器权重比较而WA所做的则是将该L2范数拉平。

其中:

最终的实验结果证明这种简单的策略性能提升非常显著。实验结果我们在这里需要指明的一点是,这种解决思路其实并没有那么完备。因为分类器权值的大小,并不总与最终输出的logits的大小正相关。这是因为如果一个分类器对应的权值很大,那么如果一个特征与其是同向的,那么logits的大小显然与该分类器权值为正相关,但是如果一个样本的特征是与该分类器的方向是反向的,则特征与分类器的内积将会是一个很小的负值,此是则权值为负相关。因此,分类器权值很大,并不总意味着模型输出的logits很大。那么为什么这种解决方案能够很好的解决新旧类的calibration的问题呢?作者在文中给了intuitive的解释:由于现在模型结构中往往使用了非负的激活函数,典型的如relu,导致模型的特征输出,分类器的权值往往都是正值,这意味着分类器权值向量与特征向量在大多数情况下的夹角都是锐角,其内积为正数,因此是正相关的。

5.5 DER[10]

基于动态特征结构的方法已经被广泛应用于解决在Task-IL中,DER是首个将动态特征结构方法应用于Class-IL的场景下,并取得优异性能的尝试。

核心摘要

DER中说明,传统的方法会陷入稳定性-可塑性困境 (stability-plasticity dilemma): 对于一个单骨架的模型,如果不施加任何限制,给它足够的可塑性,那么它在旧类样本上的表现就会产生大幅度降低;但是如果施加过多的限制则又会导致模型没有足够的可塑性来学习新的任务。而DER则实现了相较于传统方法更好的稳定性可塑性的trade-off。DER保留并冻结旧的特征提取器来保留旧的知识,同时创建一个新的可训练的特征提取使模形具有足够的可塑性来学习新的任务。DER


相关文章
|
机器学习/深度学习 自然语言处理 安全
【网安专题11.8】14Cosco跨语言代码搜索代码: (a) 训练阶段 相关程度的对比学习 对源代码(查询+目标代码)和动态运行信息进行编码 (b) 在线查询嵌入与搜索:不必计算相似性
【网安专题11.8】14Cosco跨语言代码搜索代码: (a) 训练阶段 相关程度的对比学习 对源代码(查询+目标代码)和动态运行信息进行编码 (b) 在线查询嵌入与搜索:不必计算相似性
259 0
|
6月前
|
自然语言处理 Python
【相关问题解答1】bert中文文本摘要代码:import时无法找到包时,几个潜在的原因和解决方法
【相关问题解答1】bert中文文本摘要代码:import时无法找到包时,几个潜在的原因和解决方法
53 0
利用abbrevr包批量输出期刊缩写
有时候用endnote导入文献后显示的是期刊全称,而用到缩写时候就需要去一些网站上一个个搜索,比如CASSI, LetPub、Pubumed等网站,或者Y叔公号里直接回复,而逛Github时候突然发现abbrevr这小R包中可以很快批量实现这个需求,在此记录一下。
106 2
|
存储 算法 数据可视化
万文长字总结「类别增量学习」的前世今生、开源工具包(2)
万文长字总结「类别增量学习」的前世今生、开源工具包
236 0
|
算法 PyTorch 算法框架/工具
万文长字总结「类别增量学习」的前世今生、开源工具包(3)
万文长字总结「类别增量学习」的前世今生、开源工具包
371 0
|
机器学习/深度学习 人工智能 算法
纠错数据标注,只需一行代码:开源项目Cleanlab发布了2.0版本
纠错数据标注,只需一行代码:开源项目Cleanlab发布了2.0版本
212 0
|
JavaScript 开发工具 开发者
(简易)测试数据构造平台:33 - 正文开始-工具使用功能
(简易)测试数据构造平台:33 - 正文开始-工具使用功能
|
自然语言处理
歧义代词数据集有哪些公开数据集的下载方式
Winograd Schema Challenge (WSC)数据集的下载网站是:https://cs.nyu.edu/faculty/davise/papers/WinogradSchemas/WSCollection.xml。
211 0
|
机器学习/深度学习 编解码 自然语言处理
错字修改 | 布署1个中文文文本拼蟹纠错模型
错字修改 | 布署1个中文文文本拼蟹纠错模型
307 0
Argo 数据集下载地址-具体到每天数据(包含数据说明书)
将所用的Argo数据下载地址和一键下载方式分享给大家
Argo 数据集下载地址-具体到每天数据(包含数据说明书)