从Transformer到扩散模型,一文了解基于序列建模的强化学习方法

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
模型训练 PAI-DLC,100CU*H 3个月
交互式建模 PAI-DSW,每月250计算时 3个月
简介: 从Transformer到扩散模型,一文了解基于序列建模的强化学习方法

本文将简单谈谈基于序列建模的强化学习方法。

大规模生成模型在近两年为自然语言处理甚至计算机视觉带来的巨大的突破。最近这股风潮也影响到了强化学习,尤其是离线强化学习(offline RL),诸如 Decision Transformer (DT)[1], Trajectory Transformer(TT)[2], Gato[3], Diffuser[4]这样的方法,将强化学习的数据(包括状态,动作,奖励和 return-to-go)当成了一串去结构化的序列数据,并将建模这些序列数据作为学习的核心任务。这些模型都可以用监督或自监督学习的方法来训练,避免了传统强化学习中比较不稳定的梯度信号。即便使用复杂的策略提升 (policy improvement) 和估值 (value estimation) 方法,它们在离线强化学习中都展现了非常好的表现。


本篇将简单谈谈这些基于序列建模的强化学习方法,下篇笔者将介绍我们新提出的,Trajectory Autoencoding Planner(TAP),一种用 Vector Quantised Variational AutoEncoder (VQ-VAE)进行序列建模并进行高效的隐动作空间(latent action space)内规划的方法。

Transformer 与强化学习
Transformer 架构 [5] 于 2017 年提出之后慢慢引发了自然语言处理的革命,后续的 BERT 和 GPT-3 逐渐将自监督 + Transformer 这个组合不断推向新的高度,在自然语言处理领域不断涌现出少样本 (few-shot) 学习等性质的同时,也开始向诸如计算机视觉的领域不断扩散[6][7]。

然而对于强化学习来说,这个进程似乎在 2021 年之前都不是特别明显。在 2018 年,多头注意力机制也被引入强化学习 [8],这类工作基本都是应用在类似半符号化(sub-symbolic) 的领域尝试解决强化学习泛化的问题。之后这类尝试就一直处于一种不温不火的状态。根据笔者个人的体验,实际上 Transformer 在强化学习上也并没有展现出稳定的压倒性的优势,而且还很难训练。在 20 年我们的一个用 Relational GCN 做强化学习的工作中 [9],我们其实也在背后试过 Transformer,但是基本比传统结构(类似 CNN)差得多,很难稳定训练得到一个能用的 policy。为什么 Transformer 和传统在线强化学习(online RL)的相性比较差还是个开放问题,比如 Melo[10] 解释说是因为传统的 Transformer 的参数初始化不适合强化学习,在此我就不多做讨论了。

2021 年年中,Decision Transformer (DT)和 Trajectory Transformer(TT)的发表掀起了 Transformer 在 RL 上应用的新大潮。这两个工作的思路其实很直接:如果 Transformer 和在线强化学习的算法不是很搭,那要不干脆把强化学习当成一个自监督学习的任务?趁着离线强化学习这个概念也很火热,这两个工作都将自己的主要目标任务锁定为建模离线数据集(offline dataset),然后再将这个序列模型用作控制和决策。

对于强化学习来说,所谓序列就是由状态(state) s ,动作(action) ,奖励(reward) r 和价值(value) v 构成的轨迹(trajectory) 其中价值目前一般是被用 return-to-go 来替代,可以被看成是一种蒙特卡洛估计(Monte Carlo estimation)。离线数据集就由这一条条轨迹构成。轨迹的产生和环境的动力学模型(dynamics)以及行为策略(behaviour policy)有关。而所谓序列建模,就是要建模产生产生这个序列的概率分布(distribution),或者严格上说是其中的一些条件概率。


Decision Transformer
DT 的做法是建模一个从过往数据和价值到动作的映射 (return-conditioned policy),也就是建模了一个动作的条件概率的数学期望这种思路很类似于 Upside Down RL[11],不过很有可能背后的直接动机是模仿 GPT2/3 那种根据提示词(prompt) 完成下游任务的做法。这种做法的一个问题是要决定什么是最好的目标价值没有一个比较系统化的方法。然而 DT 的作者们发现哪怕将目标价值设为整个数据集中的最高 return,最后 DT 的表现也可以很不错。

Decision Transformer, Figure 1

对于有强化学习背景的人来说,DT 这样的方法能取得很强的表现是非常反直觉的。如果说 DQN,策略梯度(Policy Gradient)这类方法还可以只把神经网络当成一个能做插值泛化的拟合函数,强化学习中的策略提升、估值仍然是构造策略的核心的话。DT 就完全可以说是以神经网络为核心的了,背后它如何把一个可能不切实际的高目标价值联系到一个合适的动作的整个过程都完全是黑箱。DT 的成功可以说从强化学习的角度来看有些没有道理,不过我觉得这也正是这种实证研究的魅力所在。笔者认为神经网络,或者说 Transformer 的泛化能力可能超乎整个 RL 社群之前的预期。

DT 在所有序列建模方法中也是非常简单的,几乎所有强化学习的核心问题都在 Transformer 内部被解决了。这种简单性是它目前最受青睐的原因之一。不过它黑盒的性质也导致我们在算法设计层面上失去了很多抓手,传统的强化学习中的一些成果很难被融入其中。而这些成果的有效性已经在一些超大规模的实验(如 AlphaGo, AlphaStar, VPT)中被反复证实了。

Trajectory Transformer
TT 的做法则更类似传统的基于模型的强化学习 (model-based RL) 的规划(planning)方法。在建模方面,它将整个序列中的元素都离散化,然后用了 GPT-2 那样的离散的自回归(auto-regressive)方式来建模整个离线数据集。这使得它能够建模任意给定除去 return-to-go 的序列的后续 因为建模了后续序列的分布,TT 其实就成为了一个序列生成模型。通过在生成的序列中寻找拥有更好的估值(value estimation)的序列,TT 就可以输出一个“最优规划”。至于寻找最优序列的方法,TT 用了一种自然语言常用的方法:beam search 的一种变种。基本上就是永远保留已经展开的序列中最优的一部分序列,然后在它们的基础上寻找下一步的最优序列集

从强化学习的角度来说,TT 没有 DT 那么离经叛道。它的有趣之处在于(和 DT 一样)完全抛弃了原本强化学习中马尔可夫决策过程(Markov Decision Process)的因果图结构。之前的基于模型的方法比如,PETS, world model, dramerv2 等,都会遵循马尔可夫过程(或者隐式马尔可夫)中策略函数、转移函数、奖励函数等的定义,也就是状态分布的条件是上一步的状态,而动作、奖励、价值都由当前的状态决定。整个强化学习社区一般相信这样能提高样本效率,不过这样的图结构其实也可能是一种制约。自然语言领域从 RNN 到 Transformer 以及计算机视觉领域 CNN 到 Transformer 的转变其实都体现了:随着数据增加,让网络自己学习图结构更有利于获得表现更好的模型。

DreamerV2, Figure 3
由于 TT 基本上把所有序列预测的任务都交给了 Transformer,Transformer 就能更加灵活地从数据中学习出更好的图结构。如下图,TT 建模出的行为策略根据不同的任务和数据集展现出不同的图结构。图左对应了传统的马尔可夫策略,图右对应了一种动作滑动平均的策略。

Trajectory Transformer, Figure 4

Transformer 强大的序列建模能力带来了更高的长序列建模精度,下图展示了 TT 在 100 步以上的预测仍然保持了高精度,而遵循马尔可夫性质的单步预测模型很快因为预测误差叠加的问题崩溃了。

Trajectory Transformer, Figure 2
TT 虽然在具体建模和预测方面和传统方法有所不同,它提供的预测能力还是给未来融入强化学习的其它成果留出了很好的抓手。然而 TT 在预测速度上有一个重要问题:因为需要建模整个序列的分布,它将序列中所有的元素按照维度进行离散化,这也就是说一个 100 维的状态就需要占用序列中的 100 个位置,这使得被建模的序列的实际长度很容易变得特别长。而对于 Transformer,它关于序列长度 N 的运算复杂度是 ,这使得从 TT 中采样一个对未来的预测变得非常昂贵。哪怕 100 维以下的任务 TT 也需要数秒甚至数十秒来进行一步决策,这样的模型很难被投入实时的机器人控制或者在线学习之中。

Gato
Gato 是 Deepmind 发表的“通才模型”,其实就是一个跨模态多任务生成模型。用同一个 Transformer 它可以完成从自然语言问答,图片描述,玩电子游戏到机器人控制等各类工作。在针对连续控制(continous control)的建模方面 Gato 的做法基本上和 TT 类似。只不过 Gato 严格意义并不是在做强化学习,它只是建模了专家策略产生的序列数据,然后在行动时它只需要采样下一个动作,其实是对专家策略的一种模仿。

Gato Blog

其它序列生成模型:扩散模型
最近在图片生成领域扩散模型(Diffusion Model)可以说是大红大紫,DALLE-2 和 Stable Diffusion 都是基于它进行图片生成的。Diffuser 就将这个方法也运用到了离线强化学习当中,其思路和 TT 类似,先建模序列的条件分布,然后根据当前状态采样未来可能的序列。

Diffuser 相比 TT 又拥有了更强的灵活性:它可以在设定起点和终点的情形下让模型填充出中间的路径,这样就能实现目标驱动(而非最大化奖励函数)的控制。它还可以将多个目标和先验的达成目标的条件混合起来帮助模型完成任务。

Diffuser Figure 1
Diffuser 相对于传统的强化学习模型也是比较颠覆的,它生成的计划不是在时间轴上逐步展开,而是从整个序列意义上的模糊变得逐渐精确。扩散模型本身的进一步研究也是计算机视觉中的一个火热的话题,在其模型本身上很可能未来几年也会有突破。

不过扩散模型本身目前相比于其它生成模型有一个特别的缺陷,那就是它的生成速度相比于其它生成模型会更慢。很多相关领域的专家认为这一点可能在未来几年内会被缓解。不过数秒的生成时间目前对于强化学习需要实时控制的情景来说是很难接受的。Diffuser 提出了能够提升生成速度的方法:从上一步的计划开始增加少量噪音来重新生成下一步的计划,不过这样做会一定程度上降低模型的表现。

参考

  1. Decision Transformer: Reinforcement Learning via Sequence Modeling https://arxiv.org/abs/2106.01345
  2. Offline Reinforcement Learning as One Big Sequence Modeling Problem https://arxiv.org/abs/2106.02039
  3. A Generalist Agent https://arxiv.org/abs/2205.06175
  4. Planning with Diffusion for Flexible Behavior Synthesis https://arxiv.org/abs/2205.09991
  5. Attention Is All You Need https://arxiv.org/abs/1706.03762
  6. An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale https://arxiv.org/abs/2010.11929
  7. Masked Autoencoders Are Scalable Vision Learners https://arxiv.org/abs/2111.06377
  8. Relational Deep Reinforcement Learning https://arxiv.org/abs/1806.01830
  9. Grid-to-Graph: Flexible Spatial Relational Inductive Biases for Reinforcement Learning https://arxiv.org/abs/2102.04220
  10. Transformers are Meta-Reinforcement Learners https://arxiv.org/abs/2206.06614
  11. Reinforcement Learning Upside Down: Don't Predict Rewards -- Just Map Them to Actions https://arxiv.org/abs/1912.02875


相关文章
|
9天前
|
Web App开发 开发框架 .NET
使用 Playwright MCP 实现小红书全自动发布的完整流程
告别小红书自动化中的登录难题!本文手把手教你使用Playwright MCP,通过复用已登录浏览器会话,实现图文发布全程无人值守。无需应对验证码,避免登录态失效,真正实现稳定、高效的自动化操作,助你轻松提升运营效率。
|
9月前
|
机器学习/深度学习 算法 PyTorch
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
软演员-评论家算法(Soft Actor-Critic, SAC)是深度强化学习领域的重要进展,基于最大熵框架优化策略,在探索与利用之间实现动态平衡。SAC通过双Q网络设计和自适应温度参数,提升了训练稳定性和样本效率。本文详细解析了SAC的数学原理、网络架构及PyTorch实现,涵盖演员网络的动作采样与对数概率计算、评论家网络的Q值估计及其损失函数,并介绍了完整的SAC智能体实现流程。SAC在连续动作空间中表现出色,具有高样本效率和稳定的训练过程,适合实际应用场景。
2440 7
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
|
12月前
|
算法 C++
【算法解题思想】动态规划+深度优先搜索(C/C++)
【算法解题思想】动态规划+深度优先搜索(C/C++)
|
Java Docker 微服务
微服务架构的概念、特点以及如何在Java Web开发中实现微服务。
微服务架构的概念、特点以及如何在Java Web开发中实现微服务。
234 1
|
机器学习/深度学习 自然语言处理 自动驾驶
深度学习中的自监督学习:突破数据标注瓶颈的新路径
随着深度学习在各个领域的广泛应用,数据标注的高成本和耗时逐渐成为限制其发展的瓶颈。自监督学习作为一种无需大量人工标注数据的方法,正在引起越来越多的关注。本文探讨了自监督学习的基本原理、经典方法及其在实际应用中的优势与挑战。
567 27
|
存储 JavaScript 前端开发
|
Python
解决Pycharm安装后无法导入库的问题
解决Pycharm导入库问题:进入Settings,选择Project的`Python Interpreter`,点击Add Interpreter。删除`.venv`文件夹内容,然后关闭并重启Pycharm以初始化新环境,现在可以正常导入库了。
546 1
解决Pycharm安装后无法导入库的问题
|
存储 机器学习/深度学习 自然语言处理
Transformer 自然语言处理(二)
Transformer 自然语言处理(二)
482 0
Transformer 自然语言处理(二)
|
机器学习/深度学习 数据采集 人工智能
清北联合出品!一篇Survey整明白「Transformer+强化学习」的来龙去脉
清北联合出品!一篇Survey整明白「Transformer+强化学习」的来龙去脉
431 0
|
存储 XML JavaScript
圣诞节到了,用代码给对象写一颗圣诞树吧
JS是JavaScript的缩写,它是一种广泛使用的编程语言。JavaScript通常用于在web页面中添加动态内容、交互式特效和用户体验增强等功能。它是一种脚本语言,可以在浏览器中直接运行,也可以与服务器端进行交互。JavaScript可以用于创建复杂的应用程序,包括网页、手机应用、桌面应用以及游戏等。它具有广泛的应用领域,并且拥有大量的开发资源和社区支持。
385 4

热门文章

最新文章