再掀强化学习变革!DeepMind提出「算法蒸馏」:可探索的预训练强化学习Transformer

简介: 再掀强化学习变革!DeepMind提出「算法蒸馏」:可探索的预训练强化学习Transformer
【新智元导读】该怎么把预训练Transformer范式用到强化学习里?(文末为论文作者的ppt讲解)

在当下的序列建模任务上,Transformer可谓是最强大的神经网络架构,并且经过预训练的Transformer模型可以将prompt作为条件或上下文学习(in-context learning)适应不同的下游任务。


大型预训练Transformer模型的泛化能力已经在多个领域得到验证,如文本补全、语言理解、图像生成等等。



从去年开始,已经有相关工作证明,通过将离线强化学习(offline RL)视为一个序列预测问题,那么模型就可以从离线数据中学习策略


但目前的方法要么是从不包含学习的数据中学习策略(如通过蒸馏固定的专家策略),要么是从包含学习的数据(如智能体的重放缓冲区)中学习,但由于其context太小,以至于无法捕捉到策略提升。



DeepMind的研究人员通过观察发现,原则上强化学习算法训练中学习的顺序性(sequential nature)可以将强化学习过程本身建模为一个「因果序列预测问题」


具体来说,如果一个Transformer的上下文足够长到可以包含由于学习更新而产生的策略改进,那它应该不仅能够表示一个固定的策略,而且能够通过关注之前episodes的状态、行动和奖励表示为一个策略提升算子(policy improvement operator)。


这也提供了一种技术上的可行性,即任何RL算法都可以通过模仿学习蒸馏成一个足够强大的序列模型,并将其转化为一个in-context RL算法。


基于此,DeepMind提出了算法蒸馏(Algorithm Distillation, AD) ,通过建立因果序列模型将强化学习算法提取到神经网络中。


论文链接:https://arxiv.org/pdf/2210.14215.pdf


算法蒸馏将学习强化学习视为一个跨episode的序列预测问题,通过源RL算法生成一个学习历史数据集,然后根据学习历史作为上下文,通过自回归预测行为来训练因果Transformer。


与蒸馏后学习(post-learning)或专家序列的序列策略预测结构不同,AD能够在不更新其网络参数的情况下完全在上下文中改进其策略。


  • Transfomer收集自己的数据,并在新任务上最大化奖励;
  • 无需prompting或微调;
  • 在权重冻结的情况下,Transformer可探索、利用和最大化上下文的返回(return)!诸如Gato类的专家蒸馏(Expert Distillation)方法无法探索,也无法最大化返回。

实验结果证明了AD可以在稀疏奖励、组合任务结构和基于像素观察的各种环境中进行强化学习,并且AD学习的数据效率(data-efficient)比生成源数据的RL算法更高。


AD也是第一个通过对具有模仿损失(imitation loss)的离线数据进行序列建模来展示in-context强化学习的方法。


算法蒸馏


2021年,有研究人员首先发现Transformer可以通过模仿学习从离线RL数据中学习单任务策略,随后又被扩展为可以在同域和跨域设置中提取多任务策略。


这些工作为提取通用的多任务策略提出了一个很有前景的范式:首先收集大量不同的环境互动数据集,然后通过序列建模从数据中提取一个策略。


把通过模仿学习从离线RL数据中学习策略的方法也称之为离线策略蒸馏,或者简称为策略蒸馏(Policy Distillation, PD)


尽管PD的思路非常简单,并且十分易于扩展,但PD有一个重大的缺陷:生成的策略并没有从与环境的额外互动中得到提升。


例如,MultiGame Decision Transformer(MGDT)学习了一个可以玩大量Atari游戏的返回条件策略,而Gato通过上下文推断任务,学习了一个在不同环境中解决任务的策略,但这两种方法都不能通过试错来改进其策略。


MGDT通过微调模型的权重使变压器适应新的任务,而Gato则需要专家的示范提示才能适应新的任务。


简而言之,Policy Distillation方法学习政策而非强化学习算法。


研究人员假设Policy Distillation不能通过试错来改进的原因是,它在没有显示学习进展的数据上进行训练。


算法蒸馏(AD)通过优化一个RL算法的学习历史上的因果序列预测损失来学习内涵式策略改进算子的方法。



AD包括两个组成部分

1、通过保存一个RL算法在许多单独任务上的训练历史,生成一个大型的多任务数据集;

2、将Transformer使用前面的学习历史作为其背景对行动进行因果建模。


由于策略在源RL算法的整个训练过程中不断改进,AD必须得学习如何改进算子,才能准确模拟训练历史中任何给定点的行动。


最重要的是,Transformer的上下文大小必须足够大(即跨周期),以捕捉训练数据的改进。



在实验部分,为了探索AD在in-context RL能力上的优势,研究人员把重点放在预训练后不能通过zero-shot 泛化解决的环境上,即要求每个环境支持多种任务,且模型无法轻易地从观察中推断出任务的解决方案。同时episodes需要足够短以便可以训练跨episode的因果Transformer。



在四个环境Adversarial Bandit、Dark Room、Dark Key-to-Door、DMLab Watermaze的实验结果中可以看到,通过模仿基于梯度的RL算法,使用具有足够大上下文的因果Transformer,AD可以完全在上下文中强化学习新任务。



AD能够进行in-context中的探索、时间上的信用分配和泛化,AD学习的算法比产生Transformer训练的源数据的算法更有数据效率。


PPT讲解


为了方便论文理解,论文的一作Michael Laskin在推特上发表了一份ppt讲解。



算法蒸馏的实验表明,Transformer可以通过试错自主改善模型,并且不用更新权重,无需提示、也无需微调。单个Transformer可以收集自己的数据,并在新任务上将奖励最大化。


尽管目前已经有很多成功的模型展示了Transformer如何在上下文中学习,但Transformer还没有被证明可以在上下文中强化学习。


为了适应新的任务,开发者要么需要手动指定一个提示,要么需要调整模型。

如果Transformer可以适应强化学习,做到开箱即用岂不美哉?


但Decision Transformers或者Gato只能从离线数据中学习策略,无法通过反复实验自动改进。



使用算法蒸馏(AD)的预训练方法生成的Transformer可以在上下文中强化学习。



首先训练一个强化学习算法的多个副本来解决不同的任务和保存学习历史。



一旦收集完学习历史的数据集,就可以训练一个Transformer来预测之前的学习历史的行动。


由于策略在历史上有所改进,因此准确地预测行动将会迫使Transformer对策略提升进行建模。



整个过程就是这么简单,Transformer只是通过模仿动作来训练,没有像常见的强化学习模型所用的Q值,没有长的操作-动作-奖励序列,也没有像 DTs 那样的返回条件。


在上下文中,强化学习没有额外开销,然后通过观察 AD 是否能最大化新任务的奖励来评估模型。


Transformer探索、利用、并最大化返回在上下文时,它的权重是冻结的!

另一方面,专家蒸馏(最类似于Gato)不能探索,也不能最大化回报。



AD 可以提取任何 RL 算法,研究人员尝试了 UCB、DQNA2C,一个有趣的发现是,在上下文 RL 算法学习中,AD更有数据效率。



用户还可以输入prompt和次优的demo,模型会自动进行策略提升,直到获得最优解!


而专家蒸馏ED只能维持次优的demo表现。



只有当Transformer的上下文足够长,跨越多个episode时,上下文RL才会出现。


AD需要一个足够长的历史,以进行有效的模型改进和identify任务。



通过实验,研究人员得出以下结论:

  • Transformer可以在上下文中进行 RL
  • 带 AD 的上下文 RL 算法比基于梯度的源 RL 算法更有效
  • AD提升了次优策略
  • in-context强化学习产生于长上下文的模仿学习


参考资料:https://www.reddit.com/r/MachineLearning/comments/ye578r/r_incontext_reinforcement_learning_with_algorithm/https://twitter.com/MishaLaskin/status/1585265485314129926

相关文章
|
11天前
|
机器学习/深度学习 算法 数据可视化
从另一个视角看Transformer:注意力机制就是可微分的k-NN算法
注意力机制可理解为一种“软k-NN”:查询向量通过缩放点积计算与各键的相似度,softmax归一化为权重,对值向量加权平均。1/√d缩放防止高维饱和,掩码控制信息流动(如因果、填充)。不同相似度函数(点积、余弦、RBF)对应不同归纳偏置,多头则在多个子空间并行该过程。
165 6
|
5月前
|
机器学习/深度学习 数据采集 算法
智能限速算法:基于强化学习的动态请求间隔控制
本文分享了通过强化学习解决抖音爬虫限速问题的技术实践。针对固定速率请求易被封禁的问题,引入基于DQN的动态请求间隔控制算法,智能调整请求间隔以平衡效率与稳定性。文中详细描述了真实经历、问题分析、技术突破及代码实现,包括代理配置、状态设计与奖励机制,并反思成长,提出未来优化方向。此方法具通用性,适用于多种动态节奏控制场景。
149 6
智能限速算法:基于强化学习的动态请求间隔控制
|
3月前
|
机器学习/深度学习 存储 算法
强化学习算法基准测试:6种算法在多智能体环境中的表现实测
本文系统研究了多智能体强化学习的算法性能与评估框架,选用井字棋和连珠四子作为基准环境,对比分析Q-learning、蒙特卡洛、Sarsa等表格方法在对抗场景中的表现。实验表明,表格方法在小规模状态空间(如井字棋)中可有效学习策略,但在大规模状态空间(如连珠四子)中因泛化能力不足而失效,揭示了向函数逼近技术演进的必要性。研究构建了标准化评估流程,明确了不同算法的适用边界,为理解强化学习的可扩展性问题提供了实证支持与理论参考。
134 0
强化学习算法基准测试:6种算法在多智能体环境中的表现实测
|
6月前
|
机器学习/深度学习 存储 算法
18个常用的强化学习算法整理:从基础方法到高级模型的理论技术与代码实现
本文系统讲解从基本强化学习方法到高级技术(如PPO、A3C、PlaNet等)的实现原理与编码过程,旨在通过理论结合代码的方式,构建对强化学习算法的全面理解。
860 10
18个常用的强化学习算法整理:从基础方法到高级模型的理论技术与代码实现
|
4月前
|
机器学习/深度学习 算法 数据可视化
基于Qlearning强化学习的机器人迷宫路线搜索算法matlab仿真
本内容展示了基于Q-learning算法的机器人迷宫路径搜索仿真及其实现过程。通过Matlab2022a进行仿真,结果以图形形式呈现,无水印(附图1-4)。算法理论部分介绍了Q-learning的核心概念,包括智能体、环境、状态、动作和奖励,以及Q表的构建与更新方法。具体实现中,将迷宫抽象为二维网格世界,定义起点和终点,利用Q-learning训练机器人找到最优路径。核心程序代码实现了多轮训练、累计奖励值与Q值的可视化,并展示了机器人从起点到终点的路径规划过程。
129 0
|
7月前
|
机器学习/深度学习 算法 机器人
强化学习:时间差分(TD)(SARSA算法和Q-Learning算法)(看不懂算我输专栏)——手把手教你入门强化学习(六)
本文介绍了时间差分法(TD)中的两种经典算法:SARSA和Q-Learning。二者均为无模型强化学习方法,通过与环境交互估算动作价值函数。SARSA是On-Policy算法,采用ε-greedy策略进行动作选择和评估;而Q-Learning为Off-Policy算法,评估时选取下一状态中估值最大的动作。相比动态规划和蒙特卡洛方法,TD算法结合了自举更新与样本更新的优势,实现边行动边学习。文章通过生动的例子解释了两者的差异,并提供了伪代码帮助理解。
434 2
|
9月前
|
机器学习/深度学习 算法 PyTorch
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
软演员-评论家算法(Soft Actor-Critic, SAC)是深度强化学习领域的重要进展,基于最大熵框架优化策略,在探索与利用之间实现动态平衡。SAC通过双Q网络设计和自适应温度参数,提升了训练稳定性和样本效率。本文详细解析了SAC的数学原理、网络架构及PyTorch实现,涵盖演员网络的动作采样与对数概率计算、评论家网络的Q值估计及其损失函数,并介绍了完整的SAC智能体实现流程。SAC在连续动作空间中表现出色,具有高样本效率和稳定的训练过程,适合实际应用场景。
2450 7
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
|
9月前
|
机器学习/深度学习 人工智能 算法
Transformer打破三十年数学猜想!Meta研究者用AI给出反例,算法杀手攻克数学难题
《PatternBoost: Constructions in Mathematics with a Little Help from AI》提出了一种结合传统搜索算法和Transformer神经网络的PatternBoost算法,通过局部搜索和全局优化交替进行,成功应用于组合数学问题。该算法在图论中的Ramsey数研究中找到了更小的反例,推翻了一个30年的猜想,展示了AI在数学研究中的巨大潜力,但也面临可解释性和通用性的挑战。论文地址:https://arxiv.org/abs/2411.00566
198 13
|
8月前
|
机器学习/深度学习 自然语言处理 算法
生成式 AI 大语言模型(LLMs)核心算法及源码解析:预训练篇
生成式 AI 大语言模型(LLMs)核心算法及源码解析:预训练篇
1372 0
|
10月前
|
机器学习/深度学习 人工智能 算法
深入解析图神经网络:Graph Transformer的算法基础与工程实践
Graph Transformer是一种结合了Transformer自注意力机制与图神经网络(GNNs)特点的神经网络模型,专为处理图结构数据而设计。它通过改进的数据表示方法、自注意力机制、拉普拉斯位置编码、消息传递与聚合机制等核心技术,实现了对图中节点间关系信息的高效处理及长程依赖关系的捕捉,显著提升了图相关任务的性能。本文详细解析了Graph Transformer的技术原理、实现细节及应用场景,并通过图书推荐系统的实例,展示了其在实际问题解决中的强大能力。
1117 30

热门文章

最新文章