再掀强化学习变革!DeepMind提出「算法蒸馏」:可探索的预训练强化学习Transformer

简介: 再掀强化学习变革!DeepMind提出「算法蒸馏」:可探索的预训练强化学习Transformer
【新智元导读】该怎么把预训练Transformer范式用到强化学习里?(文末为论文作者的ppt讲解)

在当下的序列建模任务上,Transformer可谓是最强大的神经网络架构,并且经过预训练的Transformer模型可以将prompt作为条件或上下文学习(in-context learning)适应不同的下游任务。


大型预训练Transformer模型的泛化能力已经在多个领域得到验证,如文本补全、语言理解、图像生成等等。



从去年开始,已经有相关工作证明,通过将离线强化学习(offline RL)视为一个序列预测问题,那么模型就可以从离线数据中学习策略


但目前的方法要么是从不包含学习的数据中学习策略(如通过蒸馏固定的专家策略),要么是从包含学习的数据(如智能体的重放缓冲区)中学习,但由于其context太小,以至于无法捕捉到策略提升。



DeepMind的研究人员通过观察发现,原则上强化学习算法训练中学习的顺序性(sequential nature)可以将强化学习过程本身建模为一个「因果序列预测问题」


具体来说,如果一个Transformer的上下文足够长到可以包含由于学习更新而产生的策略改进,那它应该不仅能够表示一个固定的策略,而且能够通过关注之前episodes的状态、行动和奖励表示为一个策略提升算子(policy improvement operator)。


这也提供了一种技术上的可行性,即任何RL算法都可以通过模仿学习蒸馏成一个足够强大的序列模型,并将其转化为一个in-context RL算法。


基于此,DeepMind提出了算法蒸馏(Algorithm Distillation, AD) ,通过建立因果序列模型将强化学习算法提取到神经网络中。


论文链接:https://arxiv.org/pdf/2210.14215.pdf


算法蒸馏将学习强化学习视为一个跨episode的序列预测问题,通过源RL算法生成一个学习历史数据集,然后根据学习历史作为上下文,通过自回归预测行为来训练因果Transformer。


与蒸馏后学习(post-learning)或专家序列的序列策略预测结构不同,AD能够在不更新其网络参数的情况下完全在上下文中改进其策略。


  • Transfomer收集自己的数据,并在新任务上最大化奖励;
  • 无需prompting或微调;
  • 在权重冻结的情况下,Transformer可探索、利用和最大化上下文的返回(return)!诸如Gato类的专家蒸馏(Expert Distillation)方法无法探索,也无法最大化返回。

实验结果证明了AD可以在稀疏奖励、组合任务结构和基于像素观察的各种环境中进行强化学习,并且AD学习的数据效率(data-efficient)比生成源数据的RL算法更高。


AD也是第一个通过对具有模仿损失(imitation loss)的离线数据进行序列建模来展示in-context强化学习的方法。


算法蒸馏


2021年,有研究人员首先发现Transformer可以通过模仿学习从离线RL数据中学习单任务策略,随后又被扩展为可以在同域和跨域设置中提取多任务策略。


这些工作为提取通用的多任务策略提出了一个很有前景的范式:首先收集大量不同的环境互动数据集,然后通过序列建模从数据中提取一个策略。


把通过模仿学习从离线RL数据中学习策略的方法也称之为离线策略蒸馏,或者简称为策略蒸馏(Policy Distillation, PD)


尽管PD的思路非常简单,并且十分易于扩展,但PD有一个重大的缺陷:生成的策略并没有从与环境的额外互动中得到提升。


例如,MultiGame Decision Transformer(MGDT)学习了一个可以玩大量Atari游戏的返回条件策略,而Gato通过上下文推断任务,学习了一个在不同环境中解决任务的策略,但这两种方法都不能通过试错来改进其策略。


MGDT通过微调模型的权重使变压器适应新的任务,而Gato则需要专家的示范提示才能适应新的任务。


简而言之,Policy Distillation方法学习政策而非强化学习算法。


研究人员假设Policy Distillation不能通过试错来改进的原因是,它在没有显示学习进展的数据上进行训练。


算法蒸馏(AD)通过优化一个RL算法的学习历史上的因果序列预测损失来学习内涵式策略改进算子的方法。



AD包括两个组成部分

1、通过保存一个RL算法在许多单独任务上的训练历史,生成一个大型的多任务数据集;

2、将Transformer使用前面的学习历史作为其背景对行动进行因果建模。


由于策略在源RL算法的整个训练过程中不断改进,AD必须得学习如何改进算子,才能准确模拟训练历史中任何给定点的行动。


最重要的是,Transformer的上下文大小必须足够大(即跨周期),以捕捉训练数据的改进。



在实验部分,为了探索AD在in-context RL能力上的优势,研究人员把重点放在预训练后不能通过zero-shot 泛化解决的环境上,即要求每个环境支持多种任务,且模型无法轻易地从观察中推断出任务的解决方案。同时episodes需要足够短以便可以训练跨episode的因果Transformer。



在四个环境Adversarial Bandit、Dark Room、Dark Key-to-Door、DMLab Watermaze的实验结果中可以看到,通过模仿基于梯度的RL算法,使用具有足够大上下文的因果Transformer,AD可以完全在上下文中强化学习新任务。



AD能够进行in-context中的探索、时间上的信用分配和泛化,AD学习的算法比产生Transformer训练的源数据的算法更有数据效率。


PPT讲解


为了方便论文理解,论文的一作Michael Laskin在推特上发表了一份ppt讲解。



算法蒸馏的实验表明,Transformer可以通过试错自主改善模型,并且不用更新权重,无需提示、也无需微调。单个Transformer可以收集自己的数据,并在新任务上将奖励最大化。


尽管目前已经有很多成功的模型展示了Transformer如何在上下文中学习,但Transformer还没有被证明可以在上下文中强化学习。


为了适应新的任务,开发者要么需要手动指定一个提示,要么需要调整模型。

如果Transformer可以适应强化学习,做到开箱即用岂不美哉?


但Decision Transformers或者Gato只能从离线数据中学习策略,无法通过反复实验自动改进。



使用算法蒸馏(AD)的预训练方法生成的Transformer可以在上下文中强化学习。



首先训练一个强化学习算法的多个副本来解决不同的任务和保存学习历史。



一旦收集完学习历史的数据集,就可以训练一个Transformer来预测之前的学习历史的行动。


由于策略在历史上有所改进,因此准确地预测行动将会迫使Transformer对策略提升进行建模。



整个过程就是这么简单,Transformer只是通过模仿动作来训练,没有像常见的强化学习模型所用的Q值,没有长的操作-动作-奖励序列,也没有像 DTs 那样的返回条件。


在上下文中,强化学习没有额外开销,然后通过观察 AD 是否能最大化新任务的奖励来评估模型。


Transformer探索、利用、并最大化返回在上下文时,它的权重是冻结的!

另一方面,专家蒸馏(最类似于Gato)不能探索,也不能最大化回报。



AD 可以提取任何 RL 算法,研究人员尝试了 UCB、DQNA2C,一个有趣的发现是,在上下文 RL 算法学习中,AD更有数据效率。



用户还可以输入prompt和次优的demo,模型会自动进行策略提升,直到获得最优解!


而专家蒸馏ED只能维持次优的demo表现。



只有当Transformer的上下文足够长,跨越多个episode时,上下文RL才会出现。


AD需要一个足够长的历史,以进行有效的模型改进和identify任务。



通过实验,研究人员得出以下结论:

  • Transformer可以在上下文中进行 RL
  • 带 AD 的上下文 RL 算法比基于梯度的源 RL 算法更有效
  • AD提升了次优策略
  • in-context强化学习产生于长上下文的模仿学习


参考资料:https://www.reddit.com/r/MachineLearning/comments/ye578r/r_incontext_reinforcement_learning_with_algorithm/https://twitter.com/MishaLaskin/status/1585265485314129926

相关文章
|
2月前
|
机器学习/深度学习 算法 Python
【Python强化学习】时序差分法Sarsa算法和Qlearning算法在冰湖问题中实战(附源码)
【Python强化学习】时序差分法Sarsa算法和Qlearning算法在冰湖问题中实战(附源码)
59 1
|
3天前
|
机器学习/深度学习 人工智能 自然语言处理
算法金 | Transformer,一个神奇的算法模型!!
**Transformer 模型的核心是自注意力机制,它改善了长序列理解,让每个单词能“注意”到其他单词。自注意力通过查询、键和值向量计算注意力得分,多头注意力允许并行处理多种关系。残差连接和层归一化加速训练并提升模型稳定性。该机制广泛应用于NLP和图像处理,如机器翻译和图像分类。通过预训练模型微调和正则化技术可进一步优化。**
26 1
算法金 | Transformer,一个神奇的算法模型!!
|
7天前
|
机器学习/深度学习 分布式计算 算法
在机器学习项目中,选择算法涉及问题类型识别(如回归、分类、聚类、强化学习)
【6月更文挑战第28天】在机器学习项目中,选择算法涉及问题类型识别(如回归、分类、聚类、强化学习)、数据规模与特性(大数据可能适合分布式算法或深度学习)、性能需求(准确性、速度、可解释性)、资源限制(计算与内存)、领域知识应用以及实验验证(交叉验证、模型比较)。迭代过程包括数据探索、模型构建、评估和优化,结合业务需求进行决策。
14 0
|
2月前
|
机器学习/深度学习 敏捷开发 算法
算法人生(1):从“强化学习”看如何“战胜拖延”
算法人生系列探讨如何将强化学习理念应用于个人成长。强化学习是一种机器学习方法,通过奖励和惩罚促使智能体优化行为策略。它包括识别环境、小步快跑、强正避负和持续调优四个步骤。将此应用于克服拖延,首先要识别拖延原因并分解目标,其次实施奖惩机制,如延迟满足和替换刺激物,最后持续调整策略以最大化效果。通过这种动态迭代过程,我们可以更好地理解和应对生活中的拖延问题。
|
2月前
|
机器学习/深度学习 算法 Python
使用Python实现强化学习算法
使用Python实现强化学习算法
29 1
使用Python实现强化学习算法
|
2月前
|
机器学习/深度学习 算法
算法人生(2):从“强化学习”看如何“活在当下”
本文探讨了强化学习的原理及其在个人生活中的启示。强化学习强调智能体在动态环境中通过与环境交互学习最优策略,不断迭代优化。这种思想类似于“活在当下”的哲学,要求人们专注于当前状态和决策,不过分依赖历史经验或担忧未来。活在当下意味着全情投入每一刻,不被过去或未来牵绊。通过减少执着,提高觉察力和静心练习,我们可以更好地活在当下,同时兼顾历史经验和未来规划。文章建议实践静心、时间管理和接纳每个瞬间,以实现更低焦虑、更高生活质量的生活艺术。
|
2月前
|
机器学习/深度学习 存储 算法
数据结构与算法 动态规划(启发式搜索、遗传算法、强化学习待完善)
数据结构与算法 动态规划(启发式搜索、遗传算法、强化学习待完善)
29 1
|
2月前
|
机器学习/深度学习 存储 人工智能
神经网络算法 —— 一文搞懂Transformer !!
神经网络算法 —— 一文搞懂Transformer !!
232 0
|
2月前
|
机器学习/深度学习 人工智能 自然语言处理
|
2月前
|
机器学习/深度学习 算法 算法框架/工具
OpenAI Gym 中级教程——深入强化学习算法
OpenAI Gym 中级教程——深入强化学习算法
210 6