Tree-CNN:一招解决深度学习中的「灾难性遗忘」

简介:

网络结构及学习策略

网络结构

Tree-CNN 模型借鉴了层分类器,树卷积神经网络由节点构成,和数据结构中的树一样,每个节点都有自己的 ID、父亲(Parent)及孩子(Children),网(Net,处理图像的卷积神经网络),LT("Labels Transform",就是每个节点所对应的标签,对于根节点和枝节点来说,可以是对最终分类类别的一种划分,对于叶节点来说,就是最终的分类类别),其中最顶部为树的根节点。

本文提出的网络结构如下图所示。对于一张图像,首先会将其送到根节点网络去分类得到“super-classes”,然后根据所识别到的“super-classes”,将图像送入对应的节点做进一步分类,得到一个更“具体”的类别,依次进行递推,直到分类出我们想要的类。

d98ef56a6d2694253fed60543d86de41071587ab

▲ 图1

其实这就和人的识别过程相似,例如有下面一堆物品:数学书、语文书、物理书、橡皮、铅笔。如果要识别物理书,我们可能要经历这样的过程,先在这一堆中找到书,然后可能还要在书里面找到理科类的书,然后再从理科类的书中找到物理书,同样我们要找铅笔的话,我们可能需要先找到文具类的物品,然后再从中找到铅笔。

学习策略

在识别方面,Tree-CNN 的思想很简单。如图 1 所示,主要就是从根节点出发,输出得到一个图像属于各个大类的概率,根据最大概率所对应的位置将识别过程转移到下一节点,这样最终我们能够到达叶节点,叶节点对应得到的就是我们要识别的结果。整个过程如图 2 所示。

84e15e5e9fbb437370387f7445d1f51c98524575

▲ 图2

如果仅按照上面的思路去做识别,其实并没有太大的意义,不仅使识别变得很麻烦,而且在下面的实验中也证明了采用该方法所得到的识别率并不会有所提高。而这篇论文最主要的目的就是要解决我们在前面提到的“灾难性遗忘问题”,即文中所说的达到“lifelong”的效果。

对于新给的类别,我们将这些类的图像输入到根节点网络中,根节点的输出为 OK×M×I,其中 K、M、I 分别为根节点的孩子数、新类别数、每类的图像数。

然后利用式(1)来求得每类图像的输出平均值 Oavg,再使用 softmax 来计算概率情况。以概率分布表示该类与根节点下面子类的相似程度。对于第 m 类,我们按照其概率分布进行排列,得到公式(3)。

c39e55499f0c6fc270f26028a7f09f27efc8f40e

根据根节点得到的概率分布,文中分别对下面三种情况进行了讨论:

d47e62d2b349aca45e42305ed6714efbe5ed61d9当输出概率中最大概率大于设定的阈值 ,则说明该类别和该位置对应的子节点有很大的关系,因此将该类别加到该子节点上;
d47e62d2b349aca45e42305ed6714efbe5ed61d9若输出概率中有多个概率值大于设定的阈值 ,就联合多个子节点来共同组成新的子节点;
d47e62d2b349aca45e42305ed6714efbe5ed61d9如果所有的输出概率值都小于阈值 ,那么就为新类别增加新的子节点,这个节点是一个叶节点。

同样,我们将会对别的支节点继续上面的操作。通过上面的这些操作,实现对新类别的学习,文中称这种学习方式为 incremental/lifelong learning。

实验方法与结果分析

在这部分,作者分别针对 CIFAR-10 及 CIFAT-100 数据集上进行了测试

实验方法

1. CIFAR-10

在 CIFAR-10 的实验中,作者选取 6 类图像作为初始训练集,又将 6 类中的为汽车、卡车设定为交通工具类,将猫、狗、马设为动物类,因此构建出的初始树的结构如图 3(a)所示

5c52b0ac81c1963f370ce4dcd8a19bf6ec951f7c

▲ 图3

具体网络结构如图 4 所示,根节点网络是包含两层卷积、两层池化的卷积神经网络,支节点是包含 3 层卷积的卷积神经网络。

6cae955c2b942e888038bb637430f54607c44667

▲ 图4

当新的类别出现时(文中将 CIFAR-10 另外 4 个类别作为新类别),按照文中的学习策略,我们先利用根节点的网络对四种类别的图片进行分类,得到的输出情况如图 5 所示,从图中可以看出,在根节点的识别中 Frog、Deer、Bird 被分类为动物的概率很高,Airplane 被分类为交通工具的概率较高。

0e449e6f7c4857f9363182349af8ed3fe1f94f78

▲ 图5

根据文中的策略,Frog、Deer、Bird 将会被加入到动物类节点,同样 Airplane 将会被加入到交通工具类节点。经过 incremental/lifelong learning 后的 Tree-CNN 的结构如图 3(b)所示。 具体训练过程如图 6 所示。

ed68b7aa5d3c79470a4859c5370b086eb2e793cc

▲ 图6

为了对比 Tree-CNN 的效果,作者又搭建了一个包含 4 层卷积的神经网络,并分别通过调节全连接层、全连接 +conv1、全连接 +conv1+conv2、全连接 +conv1+conv2+conv3、全连接 +conv1+conv2+conv3+conv4 的参数来进行微调。

2. CIFAR-100

对于 CIFAR-100 数据集,作者将 100 类数据分为 10 组,每组包含 10 类样本。在网络方面,作者将根节点网络的卷积层改为 3,并改变了全连接层的输出数目。

实验结果分析

在这部分,作者通过设置两个参数来衡量 Tree-CNN 的性能

844961a2b799297cf5c5683a9bb8e8189513fcda

其中,Training Effort 表示 incremental learning 网络的更改程度,即可以衡量“灾难性遗忘”的程度,参数改变的程度越高,遗忘度越强。

图 7 比较了在 CIFAR-10 上微调网络和 Tree-CNN 的识别效果对比,可以看出相对于微调策略,Tree-CNN 的 Training Effort 仅比微调全连接层高,而准确率却能超出微调全连接层 +conv1。

5968c63aef9667e3aa0f4f1f5a3ca29b7fb03dbe

▲ 图7

这一现象在 CIFAR-100 中表现更加明显。

460246eac44cbbcb327e7f5ffc3e7c11792c61c5

▲ 图8

从图 7、图 8 中可以看出 Tree-CNN 的准确率已经和微调整个网络相差无几,但是在 Training Effort 上却远小于微调整个网络。

从图 9 所示分类结果中可以看出,在各个枝节点中,具有相同的特性的类被分配在相同的枝节点中。这一情况在 CIFAR-100 所得到的 Tree-CNN 最终的结构中更能体现出来。

6e7009a70e2d0b38e236eb738a95e33bc254992b

除了一些叶节点外,在语义上具有相同特征的物体会被分类到同一支节点下,如图 10 所示。

d0bf213507cfcfef25c2cffc2c6d04d86b4b972a

▲ 图10

总结与分析

本文虽然在一定程度上减少了神经网络“灾难性遗忘”问题,但是从整篇文章来看,本文并没能使网络的识别准确率得到提升,反而,相对于微调整个网络来说,准确率还有所降低。

此外,本文搭建的网络实在太多,虽然各个子网络的网络结构比较简单,但是调节网络会很费时。


原文发布时间为:2018-04-25

本文作者:吴仕超

本文来自云栖社区合作伙伴“PaperWeekly”,了解相关信息可以关注“PaperWeekly”。

相关文章
|
7月前
|
机器学习/深度学习 自然语言处理 异构计算
Python深度学习面试:CNN、RNN与Transformer详解
【4月更文挑战第16天】本文介绍了深度学习面试中关于CNN、RNN和Transformer的常见问题和易错点,并提供了Python代码示例。理解这三种模型的基本组成、工作原理及其在图像识别、文本处理等任务中的应用是评估技术实力的关键。注意点包括:模型结构的混淆、过拟合的防治、输入序列长度处理、并行化训练以及模型解释性。掌握这些知识和技巧,将有助于在面试中展现优秀的深度学习能力。
238 11
|
7月前
|
机器学习/深度学习 数据可视化 算法框架/工具
深度学习第3天:CNN卷积神经网络
深度学习第3天:CNN卷积神经网络
|
6月前
|
机器学习/深度学习
【从零开始学习深度学习】23. CNN中的多通道输入及多通道输出计算方式及1X1卷积层介绍
【从零开始学习深度学习】23. CNN中的多通道输入及多通道输出计算方式及1X1卷积层介绍
【从零开始学习深度学习】23. CNN中的多通道输入及多通道输出计算方式及1X1卷积层介绍
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
揭秘深度学习中的微调难题:如何运用弹性权重巩固(EWC)策略巧妙应对灾难性遗忘,附带实战代码详解助你轻松掌握技巧
【10月更文挑战第1天】深度学习中,模型微调虽能提升性能,但常导致“灾难性遗忘”,即模型在新任务上训练后遗忘旧知识。本文介绍弹性权重巩固(EWC)方法,通过在损失函数中加入正则项来惩罚对重要参数的更改,从而缓解此问题。提供了一个基于PyTorch的实现示例,展示如何在训练过程中引入EWC损失,适用于终身学习和在线学习等场景。
108 4
揭秘深度学习中的微调难题:如何运用弹性权重巩固(EWC)策略巧妙应对灾难性遗忘,附带实战代码详解助你轻松掌握技巧
|
2月前
|
机器学习/深度学习 人工智能 监控
深入理解深度学习中的卷积神经网络(CNN):从原理到实践
【10月更文挑战第14天】深入理解深度学习中的卷积神经网络(CNN):从原理到实践
169 1
|
2月前
|
机器学习/深度学习 存储 人工智能
深度学习之不遗忘训练
基于深度学习的不遗忘训练(也称为抗遗忘训练或持久性学习)是针对模型在学习新任务时可能会忘记已学习内容的一种解决方案。该方法旨在使深度学习模型在不断接收新信息的同时,保持对旧知识的记忆。
58 4
|
2月前
|
机器学习/深度学习 编解码 算法
【深度学习】经典的深度学习模型-01 开山之作:CNN卷积神经网络LeNet-5
【深度学习】经典的深度学习模型-01 开山之作:CNN卷积神经网络LeNet-5
43 0
|
3月前
|
机器学习/深度学习 数据采集 数据可视化
深度学习实践:构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类
本文详细介绍如何使用PyTorch构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行图像分类。从数据预处理、模型定义到训练过程及结果可视化,文章全面展示了深度学习项目的全流程。通过实际操作,读者可以深入了解CNN在图像分类任务中的应用,并掌握PyTorch的基本使用方法。希望本文为您的深度学习项目提供有价值的参考与启示。
|
4月前
|
机器学习/深度学习 人工智能 TensorFlow
深度学习中的卷积神经网络(CNN)原理与实践
【8月更文挑战第31天】在人工智能的浪潮中,深度学习技术以其强大的数据处理能力脱颖而出。本文将深入浅出地探讨卷积神经网络(CNN)这一核心组件,解析其在图像识别等领域的应用原理,并通过Python代码示例带领读者步入实践。我们将从CNN的基本概念出发,逐步深入到架构设计,最后通过一个简易项目展示如何将理论应用于实际问题解决。无论你是深度学习的初学者还是希望深化理解的实践者,这篇文章都将为你提供有价值的洞见和指导。
|
5月前
|
机器学习/深度学习 人工智能 自然语言处理
算法金 | 秒懂 AI - 深度学习五大模型:RNN、CNN、Transformer、BERT、GPT 简介
**RNN**,1986年提出,用于序列数据,如语言模型和语音识别,但原始模型有梯度消失问题。**LSTM**和**GRU**通过门控解决了此问题。 **CNN**,1989年引入,擅长图像处理,卷积层和池化层提取特征,经典应用包括图像分类和物体检测,如LeNet-5。 **Transformer**,2017年由Google推出,自注意力机制实现并行计算,优化了NLP效率,如机器翻译。 **BERT**,2018年Google的双向预训练模型,通过掩码语言模型改进上下文理解,适用于问答和文本分类。
163 9

热门文章

最新文章