切换JAX,强化学习速度提升4000倍!牛津大学开源框架PureJaxRL,训练只需GPU
新智元新智元 2023-04-14 16:08 发表于北京
新智元报道
编辑:LRS
【新智元导读】加入光荣的JAX-强化学习进化!
还在为强化学习运行效率发愁?无法解释强化学习智能体的行为?
最近来自牛津大学Foerster Lab for AI Research(FLAIR)的研究人员分享了一篇博客,介绍了如何使用JAX框架仅利用GPU来高效运行强化学习算法,实现了超过4000倍的加速;并利用超高的性能,实现元进化发现算法,更好地理解强化学习算法。
文章链接:https://chrislu.page/blog/meta-disco/代码链接:https://github.com/luchris429/purejaxrl
作者团队开发的框架PureJaxRL可以极大降低进入Deep RL研究的算力需求,使学术实验室能够使用数万亿帧进行研究(缩小了与工业研究实验室的差距),独立研究人员也可以利用单个GPU进行开发。
文章作者Chris Lu是牛津大学博士二年级学生,工作重点是将进化启发(evolution-inspired)的技术应用于元学习和多智能体强化学习,曾在DeepMind实习。
使用PureJaxRL实现超过4000倍加速
GPU is all you need
大多数Deep RL的算法同时需要CPU和GPU的计算资源,通常来说,环境(environment)在CPU上运行,策略神经网络运行在GPU上,为了提高wall clock速度,开发者往往使用多个线程并行运行多个环境。
但如果是用JAX的话,可以直接将环境向量化(vectorise),并将其在GPU上运行,而无需使用CPU的多线程。
不仅可以避免在CPU和GPU之间传输数据以节省时间,如果使用JAX原语来编写环境程序,还可以使用JAX强大的vmap函数来立即创建环境的矢量化版本。
虽然在JAX中重写RL环境可能很费时间,但幸运的是,目前已经有一些库提供了各种环境:
Gymnax库包括了多个常用的环境,包括经典的控制任务,Bsuite任务和Minatar(类似Atari的)环境。
链接:https://github.com/RobertTLange/gymnax
研究人员选择Gymnax作为测试和评估代码的首选库,在这篇文章中的示例用的也是Gymnax,库中还包括许多其他非常简洁的功能,并且非常容易使用。
Brax是使用JAX运行类似Mujoco的连续控制环境的方法,该库包含许多强化学习环境,可以对标类似的连续控制环境,如HalfCheetah和Humanoid,并且也是可微的!
链接:https://github.com/google/brax
Jumanji包含许多令人特别炫酷、简单和行业驱动的环境,库中的许多环境都直接来自于行业设置,确保这里提供的环境是实用的并且与现实世界相关,具体问题包括组合问题,如著名的旅行推销员问题或3D装箱。
链接:https://github.com/instadeepai/jumanji
Pgx有许多流行的桌面游戏和其他环境,包括Connect 4、围棋、扑克!
链接:https://github.com/sotetsuk/pgx
在Gymnax的测速基线报告显示,如果用numpy使用CartPole-v1在10个环境并行运行的情况下,需要46秒才能达到100万帧;在A100上使用Gymnax,在2k 环境下并行运行只需要0.05秒,加速达到1000倍!
这个结论也适用于比Cartpole-v1更复杂的环境,例如Minatar-Breakout需要50秒才能在 CPU 上达到100万帧,而在 Gymnax 只需要0.2秒。
这些实验结果显示了多个数量级的改进,使学术研究人员能够在有限的硬件上高效地运行超过数万亿帧的实验。
在JAX中端到端地进行所有操作有几个优势:
- 在加速器上的矢量化环境运行速度更快。
- 通过将计算完全保留在GPU上,可以避免在CPU和GPU之间来回复制数据的开销,通常也是性能的一个关键瓶颈。
- 通过JIT编译实现,可以避免Python的开销,有时会阻塞发送命令之间的GPU 计算。
- JIT 编译通过运算符融合(operator fusion)可以获得显著的加速效果,即优化了GPU上的内存使用。
- 多线程的并行运行环境很难调试,并且会导致复杂的基础设施。
为了证明这些优势,作者在纯JAX环境中复制了CleanRL的PyTorch PPO基线实现,使用了相同数量的并行环境和相同的超参数设置,并且没有利用海量环境矢量化的优势。
在Cartpole-v1和MinAtar-Breakout中运行5次,训练过程如下。
Cartpole-v1和 MinAtar-Breakout上的CleanRLvs JAX PPO,给定相同的超参数和帧数,得到了几乎相同的结果。
将x轴从帧替换为wall-clock time(某个线程上实际执行的时间)后,在没有任何额外并行环境的情况下,速度提升了10倍以上。
Cartpole-v1和 MinAtar-Breakout上的CleanRL vs JAX PPO,得到了相同的结果,但是快了10倍以上!
并行运行多个智能体
虽然可以从上述技巧中得到相当不错的加速效果,但与标题中的4000倍加速仍然相去甚远。
通过向量化整个强化学习训练循环以及之前提到JAX中的vmap,可以很容易地并行训练多个智能体。
rng = jax.random.PRNGKey(42)rngs = jax.random.split(rng, 256)train_vjit = jax.jit(jax.vmap(make_train(config)))outs = train_vjit(rngs)
此外,还可以使用JAX提供的pmap函数在多个 GPU 上运行,在此之前,这种跨设备的并行化和向量化,尤其是在设备内部的并行化和向量化,是一个非常令人头疼的问题。
Cartpole-v1和 MinAtar-Breakout 上的CleanRL vs Jax PPO,可以将智能体训练本身并行化。在 Cartpole-v1上,只需要用训练一个CleanRL智能体的一半时间来训练2048个智能体。
如果正在开发一个新的强化学习算法,那么就可以在单个GPU上同时对具有统计学意义的大量种子进行快速训练。
除此之外,还可以同时训练成千上万的独立智能体,在作者提供的代码中,还展示了如何使用进行快速超参数搜索,也可以将其用于进化元学习!