少样本学习系列(二)【Model-Based Methods】

本文涉及的产品
图片翻译,图片翻译 100张
语种识别,语种识别 100万字符
文档翻译,文档翻译 1千页
简介: 少样本学习系列(二)【Model-Based Methods】

  上一节说了Metrics-Based Methods,主要是将输入编码到一个相同的特征空间,然后比较相似度。但是人类很多时候能够快速学习的原因是对以往知识、经验的利用。因此通过扩展一个记忆模块似乎也能做到少样本学习。这一节主要介绍通过模型结构的设计,来做few shot learning

Neural Turing Machines (NTMs)

  LSTM将记忆藏在隐藏节点(hidden state)中,这样就会存在很多问题,一个是计算的开销,另外一个就是记忆会被经常改动,并且是那种牵一发动全身地改变。

  而NTM由一个controller和一个memory矩阵构成,通过特定的寻址读写机制,对相关的memory进行修改,并且易于扩展:

  当给定一个输入时,controller负责依据输入对memory进行读写操作,实现记忆更新。

  但是通过特定的行和列来读取memory的话,我们就没办法对整个网络求梯度了,不能使用微分算法来更新。那这样肯定不行,要想一些办法。

  我们需要对外部的memory(存储器)进行选择性地读写。人的大脑工作的时候首先是聚焦注意力(记忆中很大一块,比如说你昨晚吃啥了,一般会聚焦到昨晚那一大块时间段),然后寻找到特定的记忆(比如有啥菜)。因此有了模糊读写(blurry read and write)的概念,通过不同的权重与内存中的所有元素进行交互

读:

  假设记忆矩阵M t stept tt 是一个拥有R 行和C 列的内存矩阵(C 代表记忆中每一行的大小)。执行读写操作的网络输出称为headscontroller输出一个attention向量,长度为R 也称之为weight vector( w t ),其中的每一个元素w t ( i ) memory matrixi ii行的weightweight通常都被归一化,用数学形式可表示为:


image.png

read head返回的就是记忆矩阵行的线性组合:

image.png

写:


  写的操作可以分为两步擦除(erasing)和添加(adding)。为了实现擦除操作,需要一个erase vectore t e_{t}et其值在0-1之间,擦除操作可表示为:


image.png

weight w t ( i ) e t 都为1时,memory被清空,当其中任意一个为0时,则不会有任何改变,这种方式也支持多个操作任意顺序的相互叠加。记忆矩阵可用一个长度为C 的向量a t 更新:

image.png

  • 寻址

  读写操作的关键就在于权重矩阵w ww,控制器产生权重矩阵可以分为以下四步:

  1. content-based addressinghead产生一个长度为C CCkey vectork t ,然后用余弦相似度度量k t 与记忆矩阵M t 的相似性:


image.png

 对记忆矩阵每一行都进行一样的操作再归一化可得content weight vector


image.png

image.png

  1. interpolation之后,head产生一个normalized shift weightings t s_{t}st,对权重进行旋转位移,比如当前的权重值关注于某一个locationmemory,经过此步就会扩展到其周围的location,使得模型对周围的memory也会做出少量的读和写操作,采用循环卷积


image.png

  1. 卷积操作之后会使得权重分布趋于均匀化,这将会导致本来集中于单个位置的焦点出现发散,这里采用锐化操作,head产生一个标量γ ≥ 1

image.png

 上述操作都是可微分的,因此可以使用微分算法对其进行优化。


Memory-augmented Neural Networks


  NMT中采用了content-based addressinglocation-based addressing,在MANN中只采用content-based addressing,因为只需要比较当前input是否和之前输入的input相似即可。



  MANN中的读取操作和NTM的读取操作非常类似,不同之处在于它只采取content-based addressing的方式,先产生一个归一化的权重向量image.png

image.png

其中K ( ) 表示余弦相似性,之后与NTM类似与记忆矩阵M t加权求和即可:

image.png


  这里采用了Least Recently Used Access (LRUA)的写入方式:


image.png

image.png

 其中m ( w t u , n ) m\left(w_{t}^{u},n\right)m(wtu,n)表示w t u w_{t}^{u}wtu中第n nn个最小的元素,只有当期够小才为1,否者为0

Meta Networks

  传统的神经网络通过stochastic gradient descent方式做更新,如果batch_size1的话,更新就会很慢。如果train一个网络去预测目标任务的网络参数的话,这样的学习起来就会很快,称之为fast weights。由此我们可以知道meta network由两部分组成:

  1. meta-learner:它所要做的事情就是获取不同task的通用的知识。可以看作是一个embeddings function,判断两个不同的数据之间的差别。
  2. base-learner:期望去学一个target task,就是最常见的学习算法,比如做个分类这样。

  开始之前定义一些术语:

  • Support set:从训练集采样得到的一些数据点( x , y ) (x,y)(x,y)
  • query set:同样也是从训练集采样得到的一些数据点( x , y ) (x,y)(x,y),作为query set。
  • Embedding functionf θ f_{\theta}fθmeta-learner的一部分,与siamese network类似,用于预测两个输入是否属于同一类。
  • Base-learner modelg ϕ g_{\phi}gϕ:就是一个需要处理完整任务的学习算法。
  • θ + \theta^{+}θ+Embedding functionf θ f_{\theta}fθfast weight,由一个LSTMF w F_{w}Fw产生。
  • ϕ + \phi^{+}ϕ+Base-learner modelg ϕ g_{\phi}gϕfast weight,由一个网络G v G_{v}Gv产生。

  可以看出slow weights( θ , ϕ ) (\theta,\phi)(θ,ϕ)构成了meta-learnersbase learners。两个不同的网络F w F_{w}FwG v G_{v}Gv生成fast weight

  meta networks的网络架构如下所示:

  可以看到meta networkbase learnermeta-learner组成,meta-learner给配了一个外部memory(external memory)。

算法

  整个训练数据被分为两部分support setS = ( x i ′ , y i ′ ) query setU = ( x i , y i ) ,我们要做的事情就是学四个网络(f ( θ ) , g ( ϕ ) , F w , G v 的参数。

  1. support set随机采样K个样本对。循环将其中每个样本1-K送入embedding functionf ( θ ) f(\theta)f(θ),并计算cross-entropy lossLembeddings
  2. 计算得到的cross-entropy lossLembeddings再经过LSTM计算image.png
  3. image.png

  然后将fastslow weight合并:


image.png

20200711203527901.png


matching networksLSTM meta-learners其实是使用了相同的策略,都有利用额外的信息,一个是contextual embeddings,一个是meta information,期望抽取出一些对于整个task比较重要的信息。

参考

相关文章
|
7月前
|
数据可视化
如何用潜类别混合效应模型(Latent Class Mixed Model ,LCMM)分析老年痴呆年龄数据
如何用潜类别混合效应模型(Latent Class Mixed Model ,LCMM)分析老年痴呆年龄数据
|
机器学习/深度学习 开发框架 .NET
YOLOv5的Tricks | 【Trick6】学习率调整策略(One Cycle Policy、余弦退火等)
YOLOv5的Tricks | 【Trick6】学习率调整策略(One Cycle Policy、余弦退火等)
2683 0
YOLOv5的Tricks | 【Trick6】学习率调整策略(One Cycle Policy、余弦退火等)
|
2月前
|
机器学习/深度学习 Web App开发 人工智能
轻量级网络论文精度笔(一):《Micro-YOLO: Exploring Efficient Methods to Compress CNN based Object Detection Model》
《Micro-YOLO: Exploring Efficient Methods to Compress CNN based Object Detection Model》这篇论文提出了一种基于YOLOv3-Tiny的轻量级目标检测模型Micro-YOLO,通过渐进式通道剪枝和轻量级卷积层,显著减少了参数数量和计算成本,同时保持了较高的检测性能。
43 2
轻量级网络论文精度笔(一):《Micro-YOLO: Exploring Efficient Methods to Compress CNN based Object Detection Model》
|
2月前
|
编解码 人工智能 文件存储
轻量级网络论文精度笔记(二):《YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object ..》
YOLOv7是一种新的实时目标检测器,通过引入可训练的免费技术包和优化的网络架构,显著提高了检测精度,同时减少了参数和计算量。该研究还提出了新的模型重参数化和标签分配策略,有效提升了模型性能。实验结果显示,YOLOv7在速度和准确性上超越了其他目标检测器。
55 0
轻量级网络论文精度笔记(二):《YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object ..》
|
5月前
|
Python
Fama-French模型,特别是三因子模型(Fama-French Three-Factor Model)
Fama-French模型,特别是三因子模型(Fama-French Three-Factor Model)
|
5月前
|
机器学习/深度学习 数据采集 监控
算法金 | DL 骚操作扫盲,神经网络设计与选择、参数初始化与优化、学习率调整与正则化、Loss Function、Bad Gradient
**神经网络与AI学习概览** - 探讨神经网络设计,包括MLP、RNN、CNN,激活函数如ReLU,以及隐藏层设计,强调网络结构与任务匹配。 - 参数初始化与优化涉及Xavier/He初始化,权重和偏置初始化,优化算法如SGD、Adam,针对不同场景选择。 - 学习率调整与正则化,如动态学习率、L1/L2正则化、早停法和Dropout,以改善训练和泛化。
53 0
算法金 | DL 骚操作扫盲,神经网络设计与选择、参数初始化与优化、学习率调整与正则化、Loss Function、Bad Gradient
|
7月前
|
数据可视化
R语言用潜类别混合效应模型(Latent Class Mixed Model ,LCMM)分析老年痴呆年龄数据
R语言用潜类别混合效应模型(Latent Class Mixed Model ,LCMM)分析老年痴呆年龄数据
114 10
|
7月前
|
vr&ar
R语言如何做马尔可夫转换模型markov switching model
R语言如何做马尔可夫转换模型markov switching model
|
7月前
|
vr&ar
R语言如何做马尔科夫转换模型markov switching model
R语言如何做马尔科夫转换模型markov switching model
|
机器学习/深度学习 数据采集 计算机视觉
少样本学习系列(一)【Metrics-Based Methods】
少样本学习系列(一)【Metrics-Based Methods】
163 0