手把手 | OpenAI开发可拓展元学习算法Reptile,能快速学习(附代码)

简介:

在OpenAI, 我们开发了一种简易的元学习算法,称为Reptile。它通过对任务进行重复采样,利用随机梯度下降法,并将初始参数更新为在该任务上学习的最终参数。

其性能可以和MAML(model-agnostic meta-learning,由伯克利AI研究所研发的一种应用广泛的元学习算法)相媲美,操作简便且计算效率更高。

MAML元学习算法:

http://bair.berkeley.edu/blog/2017/07/18/learning-to-learn/

元学习是学习如何学习的过程。此算法接受大量各种的任务进行训练,每项任务都是一个学习问题,然后产生一个快速的学习器,并且能够通过少量的样本进行泛化。

一个深入研究的元学习问题是小样本分类(few-shot classification),其中每项任务都是一个分类问题,学习器在每个类别下只能看到1到5个输入-输出样本(input-output examples),然后就要给新输入的样本进行分类。

下面是应用了Reptile算法的单样本分类(1-shot classification)的互动演示,大家可以尝试一下。

601fe0a17f3e2b9b60a5a6cdfd9f7568e963d8fa

尝试单击“Edit All”按钮,绘制三个不同的形状或符号,然后在右侧的输入区中绘制其中一个,并查看Reptile如何对它进行分类。前三张图是标记样本,每图定义一个类别。最后一张图代表未知样本,Reptile要输出此图属于每个类别的概率。

Reptile的工作原理

像MAML一样,Reptile试图初始化神经网络的参数,以便通过新任务产生的少量数据来对网络进行微调。

但是,当MAML借助梯度下降算法的计算图来展开和区分时,Reptile只是以标准方法在每个任务中执行随机梯度下降(stochastic gradient descent, SGD)算法,并不展开计算图或者计算二阶导数。这使得Reptile比MAML需要更少的计算和内存。示例代码如下:

20b1b34904b609f3047e0f8f85deb9b2ecd46ede

 

初始化Φ,初始参数向量

对于迭代1,2,3……执行

随机抽样任务T

在任务T上执行k>1步的SGD,输入参数Φ,输出参数w

更新:Φ←Φ+ϵ(w−Φ)

结束

返回Φ

最后一步中,我们可以将Φ−W作为梯度,并将其插入像这篇论文里https://arxiv.org/abs/1412.6980Adam这样更为先进的优化器中作为替代方案。

首先令人惊讶的是,这种方法完全有效。如果k=1,这个算法就相当于 “联合训练”(joint training)——对多项任务的混合体执行SGD。虽然在某些情况下,联合训练可以学习到有用的初始化,但当零样本学习(zero-shot learning)不可能实现时(比如,当输出标签是随机排列时),联合训练就几乎无法学习得到结果。

Reptile要求k>1,也就是说,参数更新要依赖于损失函数的高阶导数实现,此时算法的表现和k=1(联合训练)时是完全不同的。

为了分析Reptile的工作原理,我们使用泰勒级数(Taylor series)来逼近参数更新。Reptile的更新将同一任务中不同小批量的梯度内积(inner product)最大化,从而提高了的泛化能力。

这一发现可能超出了元学习领域的指导意义,比如可以用来解释SGD的泛化性质。进一步分析表明,Reptile和MAML的更新过程很相近,都包括两个不同权重的项。

泰勒级数:

https://en.wikipedia.org/wiki/Taylor_series

在我们的实验中,展示了Reptile和MAML在Omniglot和Mini-ImageNet基准测试中对少量样本分类时产生相似的性能,由于更新具有较小的方差,因此Reptile也可以更快的收敛到解决方案。

Omniglot:

https://github.com/brendenlake/omniglot

Mini-ImageNet:

https://arxiv.org/abs/1606.04080

我们对Reptile的分析表明,通过不同的SGD梯度组合,可以获得大量不同的算法。在下图中,假设针对每一任务中不同小批量执行k步SGD,得出的梯度分别为g1,g2,…,gk。

下图显示了在 Omniglot 上由梯度之和作为元梯度而绘制出的学习曲线。g2对应一阶MAML,也就是原先MAML论文中提出的算法。由于方差缩减,纳入更多梯度明显会加速学习过程。需要注意的是,仅仅使用g1(对应k=1)并不会给这个任务带来改进,因为零样本学习的性能无法得到改善。

29078431915cd074ed17ddcf1b18595abee66f62

X坐标:外循环迭代次数

Y坐标:Omniglot对比5种方式的

5次分类的准确度

算法实现

我们在GitHub上提供了Reptile的算法实现,它使用TensorFlow来完成相关计算,并包含用于在Omniglot和Mini-ImageNet上小样本分类实验的代码。我们还发布了一个较小的JavaScript实现,对TensorFlow预先训练好的模型进行了微调。文章开头的互动演示也是借助JavaScript完成的。

GitHub:

https://github.com/openai/supervised-reptile

较小的JavaScript实现:

https://github.com/openai/supervised-reptile/tree/master/web

最后,展示一个小样本回归(few-shot regression)的简单示例,用以预测10(x,y)对的随机正弦波。该示例基于PyTorch实现,代码如下:

 

import numpy as np

import torch

from torch import nn, autograd as ag

import matplotlib.pyplot as plt

from copy import deepcopy



seed = 0

plot = True

innerstepsize = 0.02 # stepsize in inner SGD

innerepochs = 1 # number of epochs of each inner SGD

outerstepsize0 = 0.1 # stepsize of outer optimization, i.e., meta-optimization

niterations = 30000 # number of outer updates; each iteration we sample one task and update on it



rng = np.random.RandomState(seed)

torch.manual_seed(seed)



# Define task distribution

x_all = np.linspace(-5, 5, 50)[:,None] # All of the x points

ntrain = 10 # Size of training minibatches

def gen_task():

"Generate classification problem"

phase = rng.uniform(low=0, high=2*np.pi)

ampl = rng.uniform(0.1, 5)

f_randomsine = lambda x : np.sin(x + phase) * ampl

return f_randomsine



# Define model. Reptile paper uses ReLU, but Tanh gives slightly better results

model = nn.Sequential(

nn.Linear(1, 64),

nn.Tanh(),

nn.Linear(64, 64),

nn.Tanh(),

nn.Linear(64, 1),

)



def totorch(x):

return ag.Variable(torch.Tensor(x))



def train_on_batch(x, y):

x = totorch(x)

y = totorch(y)

model.zero_grad()

ypred = model(x)

loss = (ypred - y).pow(2).mean()

loss.backward()

for param in model.parameters():

param.data -= innerstepsize * param.grad.data



def predict(x):

x = totorch(x)

return model(x).data.numpy()



# Choose a fixed task and minibatch for visualization

f_plot = gen_task()

xtrain_plot = x_all[rng.choice(len(x_all), size=ntrain)]



# Reptile training loop

for iteration in range(niterations):

weights_before = deepcopy(model.state_dict())

# Generate task

f = gen_task()

y_all = f(x_all)

# Do SGD on this task

inds = rng.permutation(len(x_all))

for _ in range(innerepochs):

for start in range(0, len(x_all), ntrain):

mbinds = inds[start:start+ntrain]

train_on_batch(x_all[mbinds], y_all[mbinds])

# Interpolate between current weights and trained weights from this task

# I.e. (weights_before - weights_after) is the meta-gradient

weights_after = model.state_dict()

outerstepsize = outerstepsize0 * (1 - iteration / niterations) # linear schedule

model.load_state_dict({name :

weights_before[name] + (weights_after[name] - weights_before[name]) * outerstepsize

for name in weights_before})



# Periodically plot the results on a particular task and minibatch

if plot and iteration==0 or (iteration+1) % 1000 == 0:

plt.cla()

f = f_plot

weights_before = deepcopy(model.state_dict()) # save snapshot before evaluation

plt.plot(x_all, predict(x_all), label="pred after 0", color=(0,0,1))

for inneriter in range(32):

train_on_batch(xtrain_plot, f(xtrain_plot))

if (inneriter+1) % 8 == 0:

frac = (inneriter+1) / 32

plt.plot(x_all, predict(x_all), label="pred after %i"%(inneriter+1), color=(frac, 0, 1-frac))

plt.plot(x_all, f(x_all), label="true", color=(0,1,0))

lossval = np.square(predict(x_all) - f(x_all)).mean()

plt.plot(xtrain_plot, f(xtrain_plot), "x", label="train", color="k")

plt.ylim(-4,4)

plt.legend(loc="lower right")

plt.pause(0.01)

model.load_state_dict(weights_before) # restore from snapshot

print(f"-----------------------------")

print(f"iteration {iteration+1}")

print(f"loss on plotted curve {lossval:.3f}") # would be better to average loss over a set of examples, but this is optimized for brevity




原文发布时间为:2018-04-11

本文作者:文摘菌

本文来自云栖社区合作伙伴“大数据文摘”,了解相关信息可以关注“大数据文摘”。

相关文章
|
12天前
|
机器学习/深度学习 前端开发 算法
婚恋交友系统平台 相亲交友平台系统 婚恋交友系统APP 婚恋系统源码 婚恋交友平台开发流程 婚恋交友系统架构设计 婚恋交友系统前端/后端开发 婚恋交友系统匹配推荐算法优化
婚恋交友系统平台通过线上互动帮助单身男女找到合适伴侣,提供用户注册、个人资料填写、匹配推荐、实时聊天、社区互动等功能。开发流程包括需求分析、技术选型、系统架构设计、功能实现、测试优化和上线运维。匹配推荐算法优化是核心,通过用户行为数据分析和机器学习提高匹配准确性。
44 3
|
1月前
|
机器学习/深度学习 算法 数据挖掘
C语言在机器学习中的应用及其重要性。C语言以其高效性、灵活性和可移植性,适合开发高性能的机器学习算法,尤其在底层算法实现、嵌入式系统和高性能计算中表现突出
本文探讨了C语言在机器学习中的应用及其重要性。C语言以其高效性、灵活性和可移植性,适合开发高性能的机器学习算法,尤其在底层算法实现、嵌入式系统和高性能计算中表现突出。文章还介绍了C语言在知名机器学习库中的作用,以及与Python等语言结合使用的案例,展望了其未来发展的挑战与机遇。
48 1
|
1月前
|
存储 算法 安全
2024重生之回溯数据结构与算法系列学习之串(12)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丟脸好嘛?】
数据结构与算法系列学习之串的定义和基本操作、串的储存结构、基本操作的实现、朴素模式匹配算法、KMP算法等代码举例及图解说明;【含常见的报错问题及其对应的解决方法】你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
2024重生之回溯数据结构与算法系列学习之串(12)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丟脸好嘛?】
|
1月前
|
机器学习/深度学习 人工智能 自然语言处理
【EMNLP2024】基于多轮课程学习的大语言模型蒸馏算法 TAPIR
阿里云人工智能平台 PAI 与复旦大学王鹏教授团队合作,在自然语言处理顶级会议 EMNLP 2024 上发表论文《Distilling Instruction-following Abilities of Large Language Models with Task-aware Curriculum Planning》。
|
1月前
|
算法 安全 搜索推荐
2024重生之回溯数据结构与算法系列学习(8)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
数据结构王道第2.3章之IKUN和I原达人之数据结构与算法系列学习x单双链表精题详解、数据结构、C++、排序算法、java、动态规划你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
|
1月前
|
算法 安全 搜索推荐
2024重生之回溯数据结构与算法系列学习之单双链表精题详解(9)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
数据结构王道第2.3章之IKUN和I原达人之数据结构与算法系列学习x单双链表精题详解、数据结构、C++、排序算法、java、动态规划你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
|
1月前
|
算法 安全 NoSQL
2024重生之回溯数据结构与算法系列学习之栈和队列精题汇总(10)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
数据结构王道第3章之IKUN和I原达人之数据结构与算法系列学习栈与队列精题详解、数据结构、C++、排序算法、java、动态规划你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
|
1月前
|
算法 安全 搜索推荐
2024重生之回溯数据结构与算法系列学习之王道第2.3章节之线性表精题汇总二(5)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
IKU达人之数据结构与算法系列学习×单双链表精题详解、数据结构、C++、排序算法、java 、动态规划 你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
|
6天前
|
机器学习/深度学习 算法
基于改进遗传优化的BP神经网络金融序列预测算法matlab仿真
本项目基于改进遗传优化的BP神经网络进行金融序列预测,使用MATLAB2022A实现。通过对比BP神经网络、遗传优化BP神经网络及改进遗传优化BP神经网络,展示了三者的误差和预测曲线差异。核心程序结合遗传算法(GA)与BP神经网络,利用GA优化BP网络的初始权重和阈值,提高预测精度。GA通过选择、交叉、变异操作迭代优化,防止局部收敛,增强模型对金融市场复杂性和不确定性的适应能力。
124 80
|
2天前
|
机器学习/深度学习 数据采集 算法
基于PSO粒子群优化的CNN-GRU-SAM网络时间序列回归预测算法matlab仿真
本项目展示了基于PSO优化的CNN-GRU-SAM网络在时间序列预测中的应用。算法通过卷积层、GRU层、自注意力机制层提取特征,结合粒子群优化提升预测准确性。完整程序运行效果无水印,提供Matlab2022a版本代码,含详细中文注释和操作视频。适用于金融市场、气象预报等领域,有效处理非线性数据,提高预测稳定性和效率。

热门文章

最新文章