切换JAX,强化学习速度提升4000倍!牛津大学开源框架PureJaxRL,训练只需GPU(1)

简介: 切换JAX,强化学习速度提升4000倍!牛津大学开源框架PureJaxRL,训练只需GPU

切换JAX,强化学习速度提升4000倍!牛津大学开源框架PureJaxRL,训练只需GPU

新智元新智元 2023-04-14 16:08 发表于北京



 新智元报道  

编辑:LRS

【新智元导读】加入光荣的JAX-强化学习进化!


还在为强化学习运行效率发愁?无法解释强化学习智能体的行为?


最近来自牛津大学Foerster Lab for AI Research(FLAIR)的研究人员分享了一篇博客,介绍了如何使用JAX框架仅利用GPU来高效运行强化学习算法,实现了超过4000倍的加速;并利用超高的性能,实现元进化发现算法,更好地理解强化学习算法。


文章链接:https://chrislu.page/blog/meta-disco/代码链接:https://github.com/luchris429/purejaxrl


作者团队开发的框架PureJaxRL可以极大降低进入Deep RL研究的算力需求,使学术实验室能够使用数万亿帧进行研究(缩小了与工业研究实验室的差距),独立研究人员也可以利用单个GPU进行开发。


文章作者Chris Lu是牛津大学博士二年级学生,工作重点是将进化启发(evolution-inspired)的技术应用于元学习和多智能体强化学习,曾在DeepMind实习。


使用PureJaxRL实现超过4000倍加速


GPU is all you need
大多数Deep RL的算法同时需要CPU和GPU的计算资源,通常来说,环境(environment)在CPU上运行,策略神经网络运行在GPU上,为了提高wall clock速度,开发者往往使用多个线程并行运行多个环境。

但如果是用JAX的话,可以直接将环境向量化(vectorise),并将其在GPU上运行,而无需使用CPU的多线程。

不仅可以避免在CPU和GPU之间传输数据以节省时间,如果使用JAX原语来编写环境程序,还可以使用JAX强大的vmap函数来立即创建环境的矢量化版本。

虽然在JAX中重写RL环境可能很费时间,但幸运的是,目前已经有一些库提供了各种环境:

Gymnax库包括了多个常用的环境,包括经典的控制任务,Bsuite任务和Minatar(类似Atari的)环境。


链接:https://github.com/RobertTLange/gymnax


研究人员选择Gymnax作为测试和评估代码的首选库,在这篇文章中的示例用的也是Gymnax,库中还包括许多其他非常简洁的功能,并且非常容易使用。


Brax是使用JAX运行类似Mujoco的连续控制环境的方法,该库包含许多强化学习环境,可以对标类似的连续控制环境,如HalfCheetah和Humanoid,并且也是可微的!


链接:https://github.com/google/brax


Jumanji包含许多令人特别炫酷、简单和行业驱动的环境,库中的许多环境都直接来自于行业设置,确保这里提供的环境是实用的并且与现实世界相关,具体问题包括组合问题,如著名的旅行推销员问题或3D装箱。


链接:https://github.com/instadeepai/jumanji


Pgx有许多流行的桌面游戏和其他环境,包括Connect 4、围棋、扑克!


链接:https://github.com/sotetsuk/pgx


在Gymnax的测速基线报告显示,如果用numpy使用CartPole-v1在10个环境并行运行的情况下,需要46秒才能达到100万帧;在A100上使用Gymnax,在2k 环境下并行运行只需要0.05秒,加速达到1000倍


这个结论也适用于比Cartpole-v1更复杂的环境,例如Minatar-Breakout需要50秒才能在 CPU 上达到100万帧,而在 Gymnax 只需要0.2秒。


这些实验结果显示了多个数量级的改进,使学术研究人员能够在有限的硬件上高效地运行超过数万亿帧的实验。



在JAX中端到端地进行所有操作有几个优势:


  • 在加速器上的矢量化环境运行速度更快。
  • 通过将计算完全保留在GPU上,可以避免在CPU和GPU之间来回复制数据的开销,通常也是性能的一个关键瓶颈。
  • 通过JIT编译实现,可以避免Python的开销,有时会阻塞发送命令之间的GPU 计算。
  • JIT 编译通过运算符融合(operator fusion)可以获得显著的加速效果,即优化了GPU上的内存使用。
  • 多线程的并行运行环境很难调试,并且会导致复杂的基础设施。


为了证明这些优势,作者在纯JAX环境中复制了CleanRL的PyTorch PPO基线实现,使用了相同数量的并行环境和相同的超参数设置,并且没有利用海量环境矢量化的优势。


在Cartpole-v1和MinAtar-Breakout中运行5次,训练过程如下。


Cartpole-v1和 MinAtar-Breakout上的CleanRLvs JAX PPO,给定相同的超参数和帧数,得到了几乎相同的结果。


将x轴从帧替换为wall-clock time(某个线程上实际执行的时间)后,在没有任何额外并行环境的情况下,速度提升了10倍以上。


Cartpole-v1和 MinAtar-Breakout上的CleanRL vs JAX PPO,得到了相同的结果,但是快了10倍以上!


并行运行多个智能体


虽然可以从上述技巧中得到相当不错的加速效果,但与标题中的4000倍加速仍然相去甚远。


通过向量化整个强化学习训练循环以及之前提到JAX中的vmap,可以很容易地并行训练多个智能体。






rng = jax.random.PRNGKey(42)rngs = jax.random.split(rng, 256)train_vjit = jax.jit(jax.vmap(make_train(config)))outs = train_vjit(rngs)


此外,还可以使用JAX提供的pmap函数在多个 GPU 上运行,在此之前,这种跨设备的并行化和向量化,尤其是在设备内部的并行化和向量化,是一个非常令人头疼的问题。


Cartpole-v1和 MinAtar-Breakout 上的CleanRL vs Jax PPO,可以将智能体训练本身并行化。在 Cartpole-v1上,只需要用训练一个CleanRL智能体的一半时间来训练2048个智能体。


如果正在开发一个新的强化学习算法,那么就可以在单个GPU上同时对具有统计学意义的大量种子进行快速训练。


除此之外,还可以同时训练成千上万的独立智能体,在作者提供的代码中,还展示了如何使用进行快速超参数搜索,也可以将其用于进化元学习!



相关实践学习
基于阿里云DeepGPU实例,用AI画唯美国风少女
本实验基于阿里云DeepGPU实例,使用aiacctorch加速stable-diffusion-webui,用AI画唯美国风少女,可提升性能至高至原性能的2.6倍。
相关文章
|
3月前
|
机器学习/深度学习 存储 PyTorch
【AMP实操】解放你的GPU运行内存!在pytorch中使用自动混合精度训练
【AMP实操】解放你的GPU运行内存!在pytorch中使用自动混合精度训练
69 0
|
5月前
|
机器学习/深度学习 弹性计算 TensorFlow
阿里云GPU加速:大模型训练与推理的全流程指南
随着深度学习和大规模模型的普及,GPU成为训练和推理的关键加速器。本文将详细介绍如何利用阿里云GPU产品完成大模型的训练与推理。我们将使用Elastic GPU、阿里云深度学习镜像、ECS(云服务器)等阿里云产品,通过代码示例和详细说明,带你一步步完成整个流程。
913 0
|
5月前
|
机器学习/深度学习 异构计算 Python
Bert-vits2最终版Bert-vits2-2.3云端训练和推理(Colab免费GPU算力平台)
对于深度学习初学者来说,JupyterNoteBook的脚本运行形式显然更加友好,依托Python语言的跨平台特性,JupyterNoteBook既可以在本地线下环境运行,也可以在线上服务器上运行。GoogleColab作为免费GPU算力平台的执牛耳者,更是让JupyterNoteBook的脚本运行形式如虎添翼。 本次我们利用Bert-vits2的最终版Bert-vits2-v2.3和JupyterNoteBook的脚本来复刻生化危机6的人气角色艾达王(ada wong)。
Bert-vits2最终版Bert-vits2-2.3云端训练和推理(Colab免费GPU算力平台)
|
2月前
|
机器学习/深度学习 并行计算 PyTorch
【多GPU炼丹-绝对有用】PyTorch多GPU并行训练:深度解析与实战代码指南
本文介绍了PyTorch中利用多GPU进行深度学习的三种策略:数据并行、模型并行和两者结合。通过`DataParallel`实现数据拆分、模型不拆分,将数据批次在不同GPU上处理;数据不拆分、模型拆分则将模型组件分配到不同GPU,适用于复杂模型;数据和模型都拆分,适合大型模型,使用`DistributedDataParallel`结合`torch.distributed`进行分布式训练。代码示例展示了如何在实践中应用这些策略。
64 2
【多GPU炼丹-绝对有用】PyTorch多GPU并行训练:深度解析与实战代码指南
|
6月前
|
存储 人工智能 芯片
多GPU训练大型模型:资源分配与优化技巧 | 英伟达将推出面向中国的改良芯片HGX H20、L20 PCIe、L2 PCIe
在人工智能领域,大型模型因其强大的预测能力和泛化性能而备受瞩目。然而,随着模型规模的不断扩大,计算资源和训练时间成为制约其发展的重大挑战。特别是在英伟达禁令之后,中国AI计算行业面临前所未有的困境。为了解决这个问题,英伟达将针对中国市场推出新的AI芯片,以应对美国出口限制。本文将探讨如何在多个GPU上训练大型模型,并分析英伟达禁令对中国AI计算行业的影响。
|
5月前
|
机器学习/深度学习 缓存 PyTorch
Yolov5如何训练自定义的数据集,以及使用GPU训练,涵盖报错解决
Yolov5如何训练自定义的数据集,以及使用GPU训练,涵盖报错解决
385 0
|
5月前
|
机器学习/深度学习 人工智能 弹性计算
阿里云林立翔:基于阿里云GPU的AIGC小规模训练优化方案
阿里云弹性计算林立翔在【AIGC】话题下带来了题为《基于阿里云GPU的AIGC小规模训练优化方案》的主题演讲,围绕生成式AI技术栈、生成式AI微调训练和性能分析、ECS GPU实例为生成式AI提供算力保障、应用场景案例等相关话题展开。
|
5月前
|
XML 数据格式 异构计算
笔记 ubuntu18.04安装cuda10.2 cudnn7.5,然后进行物体检测gpu训练
笔记 ubuntu18.04安装cuda10.2 cudnn7.5,然后进行物体检测gpu训练
48 1
|
7月前
|
弹性计算 自然语言处理 数据安全/隐私保护
GPU实验室-通过GPU云服务器训练GPT-2
本文介绍如何使用GPU云服务器,使用Megatron-Deepspeed框架训练GPT-2模型并生成文本。
GPU实验室-通过GPU云服务器训练GPT-2
|
8月前
|
存储 人工智能 Cloud Native
云原生AI套件:一键训练大模型及部署GPU共享推理服务
本实验指导您如何基于容器服务ACK,使用云原生AI套件提交Bloom模型的微调训练作业,并使用GPU共享能力部署推理服务。
1106 0

热门文章

最新文章