少样本学习系列(一)【Metrics-Based Methods】

简介: 少样本学习系列(一)【Metrics-Based Methods】

  如下图所示,传统的图像分类任务是基于左边的给定训练数据,获得model,然后在右边的数据集上测试model的好坏。

  而对于小样本问题,其训练数据和测试数据如下所示:

  我们拥有的是大量的上方这些数据,也就是对于training中的airplaneautomobile等,我们有很多类数据,而对于下方Testing中像dogfrog等新的分类问题,就没有那么多类的标注数据。

  在了解Matching Networks之前,先要理解一下One-Shot Learning中的一个非常基础的概念N-way K-shot

  N-way K-shot:从Meta-dataset中随机抽取N类样本,(更简单的说法就是Support set中的类别数量,其label的组成通常称之为label set),每类样本随机抽取K+1个实例。其中每类样本中抽取K个实例组成Support set,剩下的实例组成Test set(通常为了区分真正的testing,将其称之为Query Set)。

  在特殊情况下:

  • K=1,称之为One-Shot Learning
  • K=0,称之为Zero-Shot Learning


背景


  人类能够利用已有的先验知识对为见过的类别,只需要少量数据就可以学到。较早的研究few-shot learning的文章:

  • Li Fe-Fei, Robert Fergus, and Pietro Perona. A bayesian approach to unsupervised one-shot learning of object categories. In Computer Vision, 2003. Proceedings. Ninth IEEE International Conference on, pages 1134. 1141. IEEE, 2003.
  • Li Fei-Fei, Robert Fergus, and Pietro Perona. One-shot learning of object categories. Pattern Analysis and Machine Intelligence, IEEE Transactions on, 28(4):594{611, 2006.

  和通过因果关系的:

  • Brenden M Lake, Ruslan R Salakhutdinov, and Josh Tenenbaum. One-shot learning by inverting a compositional causal process. In Advances in neural information processing systems, pages 2526{2534, 2013.
  • Brenden M Lake, Chia-ying Lee, James R Glass, and Joshua B Tenenbaum. One-shot learning of generative speech concepts. Cognitive Science Society, 2014.

  传统的计算机视觉处理图像的方法主要就是基于特征的学习和基于度量的学习,像基于度量的就是将特征编码到隐空间,早些年可能也有用核函数的方法。也有通过因果关系的,更多的是从可解释性的手段上做这些事情。


Metrics-Based Methods


  Metrics-Based Methods最主要的还是基于小样本学一个数据的表示,同时机器学习中也有一个专门的分支做这件事情:表示学习,Representation Learning

  传统的参数学习期望是给定输入X XX预测输出Y YY,通过最小化损失函数来做到这件事情,常见的损失函数有以下几类:

Siamese Network

  • 论文题目:Siamese Neural Networks for One-Shot Image Recognition

  机器学习中当只有少量数据时,获取好的特征表征是非常难做到的一件事情。作者提出孪生网络,去获取输入特征之间的相似度排序,一旦这种抽取特征的网络得到,就可以应用于新的类别数据,依据相似度去做分类。这样做的好处在于,它不仅能够在新的数据上做分类,非当前这个分布下的数据也能够很好地分类。

  孪生网络首次被提出是在下面这篇论文中:

  • Jane Bromley, James W Bentz, Leon Bottou, Isabelle Guyon, Yann LeCun, Cli Moore, Eduard Sackinger, and Roopak Shah. Signature veri cation using a siamese time delay neural network. International Journal of Pattern Recognition and Arti cial Intelligence, 7(04):669{688, 1993.

  Siamese Network通过特殊的loss函数学会去区分给定的两个输入是否相同,由两个参数一样的神经网络组成,其网络结构如下图所示:

  核心思想是将输入编码到一个隐空间,有点类似迁移学习,但是不同之处在于Siamese Network是通过contrastive loss function来做到这一点的。

  为什么要用两个参数一样的神经网络来做呢?两个网络参数一样能够学地更快,并且能够将其编码到相同的特征空间中。

  在数据预处理部分,如果两张图片是同一类,我们需要将其标为1,否者标签为0。损失函数方面,常用两类损失函数:contrastive loss functiontriplet loss function

  • Contrastive Loss Function

image.png

其中image.png表示Siamese Network,m mm表示margin,是为了使得label=1时,期望D w  的输出为0,这样就很容易将网络的权重W WW也学成0,因此加一个margin。

  • Triplet loss function

 Triplet loss function的效果一般比Contrastive Loss Function还要好一点。因为其将正例、负例样本都有考虑进loss function:


  1. Anchor ( A ): The main data point。
  2. Positive ( P ): A data point similar to Anchor。
  3. Negative ( N ):A different data point than Anchor。


 如果用距离度量的话,我们期望测试样本距离正例的距离小于负例的距离,可表示为:


image.png

其中α也是用于控制网络对任意输入输出都为0的这种情况。因此其损失函数可表示为:

image.png

Matching Network


  • 论文题目:Matching Networks for One Shot Learning

  尽管机器学习已经取得了很大的成功,但是对于给定少量数据快速学习new concepts的能力还是欠缺。作者提出了一种新的网络框架来解决这个问题。

  non-parametric的方法(e.g., nearest neighbors)可以很快学习到一种样本之间的度量方式,进而将样本分类。基于(S Roweis, G Hinton, and R Salakhutdinov. Neighbourhood component analysis. NIPS, 2004.)这篇文章,作者提出了matching network,融合parametricnon-parametric的方式。

  Matching Network将输入和标签编码到一个空间中,测试的输入也将其编码到这个空间,之后计算余弦相似度,我们就可以得到匹配信息,从而进行预测。

image.png


  其中余弦相似度c cc计算的就是support setquery set在编码空间中的相似度。那如何将数据编码呢?图像领域可以采用VGG16或者Inception这种网络结构。

  为了考虑整个support set中样本之间的关系,我们可以考虑采用bi-directional Long Short-Term Memoryg ( x i ) g(x_{i})g(xi)做优化,这样就考虑了support set中样本的context信息。

  如果在训练的时候考虑了context信息,那么由于希望编码到相同的空间,那么对于测试样本x ^,我们也希望去考虑support setcontext信息,那f ( x ^ ) 就可以表示为:


image.png

 也就是说support set样本经过g ( S ) g(S)g(S)之后得到的输出可以修改query set中样本的embedding模型(通过一个固定步数的LSTMsupport setattention,再结合emeddings ( x ^ ) \text {emeddings}(\hat{x})emeddings(x^)K KKLSTM的步数)。

Full Conditional Embedding g gg

  将Full Conditional Embeddingg gg 打开,如下图所示:

  首先将x i x_{i}xi编码成一个向量g ′ g^{\prime}g,这个可以用VGG或者Inception这种网络做个特征提取即可,之后将其经过一个Bi-LSTM,再将其三者求和。

Full Conditional Embedding f ff

  将Full Conditional Embeddingf ff 打开,如下图所示:

  先不考虑support sets ssQuery Set中的样本x ^ 先编码得到f ′ f^{\prime}f,之后经过一个LSTM得到h ^ ,再与f ′ ( x ^ )相加。


image.png


总结


  上述过程可简化为以下几步:

image.png



相关文章
|
7月前
|
数据可视化
如何用潜类别混合效应模型(Latent Class Mixed Model ,LCMM)分析老年痴呆年龄数据
如何用潜类别混合效应模型(Latent Class Mixed Model ,LCMM)分析老年痴呆年龄数据
|
机器学习/深度学习 开发框架 .NET
YOLOv5的Tricks | 【Trick6】学习率调整策略(One Cycle Policy、余弦退火等)
YOLOv5的Tricks | 【Trick6】学习率调整策略(One Cycle Policy、余弦退火等)
2683 0
YOLOv5的Tricks | 【Trick6】学习率调整策略(One Cycle Policy、余弦退火等)
|
2月前
|
机器学习/深度学习 Web App开发 人工智能
轻量级网络论文精度笔(一):《Micro-YOLO: Exploring Efficient Methods to Compress CNN based Object Detection Model》
《Micro-YOLO: Exploring Efficient Methods to Compress CNN based Object Detection Model》这篇论文提出了一种基于YOLOv3-Tiny的轻量级目标检测模型Micro-YOLO,通过渐进式通道剪枝和轻量级卷积层,显著减少了参数数量和计算成本,同时保持了较高的检测性能。
43 2
轻量级网络论文精度笔(一):《Micro-YOLO: Exploring Efficient Methods to Compress CNN based Object Detection Model》
|
2月前
|
编解码 人工智能 文件存储
轻量级网络论文精度笔记(二):《YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object ..》
YOLOv7是一种新的实时目标检测器,通过引入可训练的免费技术包和优化的网络架构,显著提高了检测精度,同时减少了参数和计算量。该研究还提出了新的模型重参数化和标签分配策略,有效提升了模型性能。实验结果显示,YOLOv7在速度和准确性上超越了其他目标检测器。
55 0
轻量级网络论文精度笔记(二):《YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object ..》
|
7月前
|
数据可视化
R语言用潜类别混合效应模型(Latent Class Mixed Model ,LCMM)分析老年痴呆年龄数据
R语言用潜类别混合效应模型(Latent Class Mixed Model ,LCMM)分析老年痴呆年龄数据
114 10
|
7月前
|
vr&ar
R语言如何做马尔可夫转换模型markov switching model
R语言如何做马尔可夫转换模型markov switching model
|
7月前
|
数据可视化
R语言建立和可视化混合效应模型mixed effect model
R语言建立和可视化混合效应模型mixed effect model
|
7月前
|
vr&ar
R语言如何做马尔科夫转换模型markov switching model
R语言如何做马尔科夫转换模型markov switching model
|
机器学习/深度学习 算法 网络架构
少样本学习系列(二)【Model-Based Methods】
少样本学习系列(二)【Model-Based Methods】
103 0
|
机器学习/深度学习 算法
少样本学习系列(三)【Optimization-Based Methods】
少样本学习系列(三)【Optimization-Based Methods】
149 0