上一节说了Metrics-Based Methods
,主要是将输入编码到一个相同的特征空间,然后比较相似度。但是人类很多时候能够快速学习的原因是对以往知识、经验的利用。因此通过扩展一个记忆模块似乎也能做到少样本学习。这一节主要介绍通过模型结构的设计,来做few shot learning
。
Neural Turing Machines (NTMs)
LSTM
将记忆藏在隐藏节点(hidden state
)中,这样就会存在很多问题,一个是计算的开销,另外一个就是记忆会被经常改动,并且是那种牵一发动全身地改变。
而NTM
由一个controller
和一个memory
矩阵构成,通过特定的寻址读写机制,对相关的memory
进行修改,并且易于扩展:
当给定一个输入时,controller
负责依据输入对memory
进行读写操作,实现记忆更新。
但是通过特定的行和列来读取memory
的话,我们就没办法对整个网络求梯度了,不能使用微分算法来更新。那这样肯定不行,要想一些办法。
我们需要对外部的memory
(存储器)进行选择性地读写。人的大脑工作的时候首先是聚焦注意力(记忆中很大一块,比如说你昨晚吃啥了,一般会聚焦到昨晚那一大块时间段),然后寻找到特定的记忆(比如有啥菜)。因此有了模糊读写(blurry read and write
)的概念,通过不同的权重与内存中的所有元素进行交互。
读:
假设记忆矩阵M t 在step
t tt 是一个拥有R 行和C 列的内存矩阵(C 代表记忆中每一行的大小)。执行读写操作的网络输出称为heads
。controller
输出一个attention
向量,长度为R 也称之为weight vector
( w t ),其中的每一个元素w t ( i ) 是memory matrix
第i ii行的weight
,weight
通常都被归一化,用数学形式可表示为:
read head
返回的就是记忆矩阵行的线性组合:
写:
写的操作可以分为两步擦除(erasing
)和添加(adding
)。为了实现擦除操作,需要一个erase vector
e t e_{t}et其值在0-1
之间,擦除操作可表示为:
当weight
w t ( i ) 和e t 都为1
时,memory
被清空,当其中任意一个为0
时,则不会有任何改变,这种方式也支持多个操作任意顺序的相互叠加。记忆矩阵可用一个长度为C 的向量a t 更新:
- 寻址
读写操作的关键就在于权重矩阵w ww,控制器产生权重矩阵可以分为以下四步:
content-based addressing
,head
产生一个长度为C CC的key vector
k t ,然后用余弦相似度度量k t 与记忆矩阵M t 的相似性:
对记忆矩阵每一行都进行一样的操作再归一化可得content weight vector
:
interpolation
之后,head
产生一个normalized shift weighting
s t s_{t}st,对权重进行旋转位移,比如当前的权重值关注于某一个location
的memory
,经过此步就会扩展到其周围的location
,使得模型对周围的memory
也会做出少量的读和写操作,采用循环卷积:
- 卷积操作之后会使得权重分布趋于均匀化,这将会导致本来集中于单个位置的焦点出现发散,这里采用锐化操作,
head
产生一个标量γ ≥ 1
上述操作都是可微分的,因此可以使用微分算法对其进行优化。
Memory-augmented Neural Networks
NMT
中采用了content-based addressing
和location-based addressing
,在MANN
中只采用content-based addressing
,因为只需要比较当前input
是否和之前输入的input
相似即可。
读
MANN
中的读取操作和NTM
的读取操作非常类似,不同之处在于它只采取content-based addressing
的方式,先产生一个归一化的权重向量
其中K ( ) 表示余弦相似性,之后与NTM
类似与记忆矩阵M t加权求和即可:
写
这里采用了Least Recently Used Access
(LRUA)的写入方式:
其中m ( w t u , n ) m\left(w_{t}^{u},n\right)m(wtu,n)表示w t u w_{t}^{u}wtu中第n nn个最小的元素,只有当期够小才为1
,否者为0
。
Meta Networks
传统的神经网络通过stochastic gradient descent
方式做更新,如果batch_size
为1
的话,更新就会很慢。如果train
一个网络去预测目标任务的网络参数的话,这样的学习起来就会很快,称之为fast weights
。由此我们可以知道meta network
由两部分组成:
meta-learner
:它所要做的事情就是获取不同task
的通用的知识。可以看作是一个embeddings function
,判断两个不同的数据之间的差别。base-learner
:期望去学一个target task
,就是最常见的学习算法,比如做个分类这样。
开始之前定义一些术语:
Support set
:从训练集采样得到的一些数据点( x , y ) (x,y)(x,y)。query set
:同样也是从训练集采样得到的一些数据点( x , y ) (x,y)(x,y),作为query set。Embedding function
f θ f_{\theta}fθ:meta-learner
的一部分,与siamese network
类似,用于预测两个输入是否属于同一类。Base-learner model
g ϕ g_{\phi}gϕ:就是一个需要处理完整任务的学习算法。- θ + \theta^{+}θ+:
Embedding function
f θ f_{\theta}fθ的fast weight
,由一个LSTM
F w F_{w}Fw产生。 - ϕ + \phi^{+}ϕ+:
Base-learner model
g ϕ g_{\phi}gϕ的fast weight
,由一个网络G v G_{v}Gv产生。
可以看出slow weights
( θ , ϕ ) (\theta,\phi)(θ,ϕ)构成了meta-learners
和base learners
。两个不同的网络F w F_{w}Fw和G v G_{v}Gv生成fast weight
。
meta networks
的网络架构如下所示:
可以看到meta network
由base learner
和meta-learner
组成,meta-learner
给配了一个外部memory
(external memory
)。
算法
整个训练数据被分为两部分support set
S = ( x i ′ , y i ′ ) 和query set
U = ( x i , y i ) ,我们要做的事情就是学四个网络(f ( θ ) , g ( ϕ ) , F w , G v 的参数。
- 从
support set
随机采样K
个样本对。循环将其中每个样本1-K送入embedding function
f ( θ ) f(\theta)f(θ),并计算cross-entropy loss
Lembeddings。 - 计算得到的
cross-entropy loss
Lembeddings再经过LSTM
计算
然后将fast
和slow weight
合并:
matching networks
和LSTM meta-learners
其实是使用了相同的策略,都有利用额外的信息,一个是contextual embeddings
,一个是meta information
,期望抽取出一些对于整个task
比较重要的信息。