如下图所示,传统的图像分类任务是基于左边的给定训练数据,获得model
,然后在右边的数据集上测试model
的好坏。
而对于小样本问题,其训练数据和测试数据如下所示:
我们拥有的是大量的上方这些数据,也就是对于training
中的airplane
、automobile
等,我们有很多类数据,而对于下方Testing
中像dog
、frog
等新的分类问题,就没有那么多类的标注数据。
在了解Matching Networks
之前,先要理解一下One-Shot Learning
中的一个非常基础的概念N-way K-shot
。
N-way K-shot
:从Meta-dataset
中随机抽取N
类样本,(更简单的说法就是Support set
中的类别数量,其label
的组成通常称之为label set
),每类样本随机抽取K+1
个实例。其中每类样本中抽取K
个实例组成Support set
,剩下的实例组成Test set
(通常为了区分真正的testing
,将其称之为Query Set
)。
在特殊情况下:
K=1
,称之为One-Shot Learning
。K=0
,称之为Zero-Shot Learning
。
背景
人类能够利用已有的先验知识对为见过的类别,只需要少量数据就可以学到。较早的研究few-shot learning
的文章:
- Li Fe-Fei, Robert Fergus, and Pietro Perona. A bayesian approach to unsupervised one-shot learning of object categories. In Computer Vision, 2003. Proceedings. Ninth IEEE International Conference on, pages 1134. 1141. IEEE, 2003.
- Li Fei-Fei, Robert Fergus, and Pietro Perona. One-shot learning of object categories. Pattern Analysis and Machine Intelligence, IEEE Transactions on, 28(4):594{611, 2006.
和通过因果关系的:
- Brenden M Lake, Ruslan R Salakhutdinov, and Josh Tenenbaum. One-shot learning by inverting a compositional causal process. In Advances in neural information processing systems, pages 2526{2534, 2013.
- Brenden M Lake, Chia-ying Lee, James R Glass, and Joshua B Tenenbaum. One-shot learning of generative speech concepts. Cognitive Science Society, 2014.
传统的计算机视觉处理图像的方法主要就是基于特征的学习和基于度量的学习,像基于度量的就是将特征编码到隐空间,早些年可能也有用核函数的方法。也有通过因果关系的,更多的是从可解释性的手段上做这些事情。
Metrics-Based Methods
Metrics-Based Methods
最主要的还是基于小样本学一个数据的表示,同时机器学习中也有一个专门的分支做这件事情:表示学习,Representation Learning
。
传统的参数学习期望是给定输入X XX预测输出Y YY,通过最小化损失函数来做到这件事情,常见的损失函数有以下几类:
Siamese Network
- 论文题目:Siamese Neural Networks for One-Shot Image Recognition
机器学习中当只有少量数据时,获取好的特征表征是非常难做到的一件事情。作者提出孪生网络,去获取输入特征之间的相似度排序,一旦这种抽取特征的网络得到,就可以应用于新的类别数据,依据相似度去做分类。这样做的好处在于,它不仅能够在新的数据上做分类,非当前这个分布下的数据也能够很好地分类。
孪生网络首次被提出是在下面这篇论文中:
- Jane Bromley, James W Bentz, Leon Bottou, Isabelle Guyon, Yann LeCun, Cli Moore, Eduard Sackinger, and Roopak Shah. Signature veri cation using a siamese time delay neural network. International Journal of Pattern Recognition and Arti cial Intelligence, 7(04):669{688, 1993.
Siamese Network
通过特殊的loss
函数学会去区分给定的两个输入是否相同,由两个参数一样的神经网络组成,其网络结构如下图所示:
核心思想是将输入编码到一个隐空间,有点类似迁移学习,但是不同之处在于Siamese Network
是通过contrastive loss function
来做到这一点的。
为什么要用两个参数一样的神经网络来做呢?两个网络参数一样能够学地更快,并且能够将其编码到相同的特征空间中。
在数据预处理部分,如果两张图片是同一类,我们需要将其标为1,否者标签为0。损失函数方面,常用两类损失函数:contrastive loss function
和triplet loss function
。
- Contrastive Loss Function:
其中表示Siamese Network,m mm表示margin,是为了使得label=1时,期望D w 的输出为0,这样就很容易将网络的权重W WW也学成0,因此加一个margin。
- Triplet loss function
Triplet loss function的效果一般比Contrastive Loss Function还要好一点。因为其将正例、负例样本都有考虑进loss function:
- Anchor ( A ): The main data point。
- Positive ( P ): A data point similar to Anchor。
- Negative ( N ):A different data point than Anchor。
如果用距离度量的话,我们期望测试样本距离正例的距离小于负例的距离,可表示为:
其中α也是用于控制网络对任意输入输出都为0
的这种情况。因此其损失函数可表示为:
Matching Network
- 论文题目:Matching Networks for One Shot Learning
尽管机器学习已经取得了很大的成功,但是对于给定少量数据快速学习new concepts
的能力还是欠缺。作者提出了一种新的网络框架来解决这个问题。
non-parametric
的方法(e.g., nearest neighbors
)可以很快学习到一种样本之间的度量方式,进而将样本分类。基于(S Roweis, G Hinton, and R Salakhutdinov. Neighbourhood component analysis. NIPS, 2004.)这篇文章,作者提出了matching network
,融合parametric
和non-parametric
的方式。
Matching Network
将输入和标签编码到一个空间中,测试的输入也将其编码到这个空间,之后计算余弦相似度,我们就可以得到匹配信息,从而进行预测。
其中余弦相似度c cc计算的就是support set
和query set
在编码空间中的相似度。那如何将数据编码呢?图像领域可以采用VGG16
或者Inception
这种网络结构。
为了考虑整个support set
中样本之间的关系,我们可以考虑采用bi-directional Long Short-Term Memory
对g ( x i ) g(x_{i})g(xi)做优化,这样就考虑了support set
中样本的context
信息。
如果在训练的时候考虑了context
信息,那么由于希望编码到相同的空间,那么对于测试样本x ^,我们也希望去考虑support set
的context
信息,那f ( x ^ ) 就可以表示为:
也就是说support set
样本经过g ( S ) g(S)g(S)之后得到的输出可以修改query set
中样本的embedding
模型(通过一个固定步数的LSTM
对support set
做attention
,再结合emeddings ( x ^ ) \text {emeddings}(\hat{x})emeddings(x^),K KK是LSTM
的步数)。
Full Conditional Embedding g gg
将Full Conditional Embedding
g gg 打开,如下图所示:
首先将x i x_{i}xi编码成一个向量g ′ g^{\prime}g′,这个可以用VGG
或者Inception
这种网络做个特征提取即可,之后将其经过一个Bi-LSTM
,再将其三者求和。
Full Conditional Embedding f ff
将Full Conditional Embedding
f ff 打开,如下图所示:
先不考虑support set
s ss,Query Set
中的样本x ^ 先编码得到f ′ f^{\prime}f′,之后经过一个LSTM
得到h ^ ,再与f ′ ( x ^ )相加。
总结
上述过程可简化为以下几步: