切换JAX,强化学习速度提升4000倍!牛津大学开源框架PureJaxRL,训练只需GPU(1)

简介: 切换JAX,强化学习速度提升4000倍!牛津大学开源框架PureJaxRL,训练只需GPU

切换JAX,强化学习速度提升4000倍!牛津大学开源框架PureJaxRL,训练只需GPU

新智元新智元 2023-04-14 16:08 发表于北京



 新智元报道  

编辑:LRS

【新智元导读】加入光荣的JAX-强化学习进化!


还在为强化学习运行效率发愁?无法解释强化学习智能体的行为?


最近来自牛津大学Foerster Lab for AI Research(FLAIR)的研究人员分享了一篇博客,介绍了如何使用JAX框架仅利用GPU来高效运行强化学习算法,实现了超过4000倍的加速;并利用超高的性能,实现元进化发现算法,更好地理解强化学习算法。


文章链接:https://chrislu.page/blog/meta-disco/代码链接:https://github.com/luchris429/purejaxrl


作者团队开发的框架PureJaxRL可以极大降低进入Deep RL研究的算力需求,使学术实验室能够使用数万亿帧进行研究(缩小了与工业研究实验室的差距),独立研究人员也可以利用单个GPU进行开发。


文章作者Chris Lu是牛津大学博士二年级学生,工作重点是将进化启发(evolution-inspired)的技术应用于元学习和多智能体强化学习,曾在DeepMind实习。


使用PureJaxRL实现超过4000倍加速


GPU is all you need
大多数Deep RL的算法同时需要CPU和GPU的计算资源,通常来说,环境(environment)在CPU上运行,策略神经网络运行在GPU上,为了提高wall clock速度,开发者往往使用多个线程并行运行多个环境。

但如果是用JAX的话,可以直接将环境向量化(vectorise),并将其在GPU上运行,而无需使用CPU的多线程。

不仅可以避免在CPU和GPU之间传输数据以节省时间,如果使用JAX原语来编写环境程序,还可以使用JAX强大的vmap函数来立即创建环境的矢量化版本。

虽然在JAX中重写RL环境可能很费时间,但幸运的是,目前已经有一些库提供了各种环境:

Gymnax库包括了多个常用的环境,包括经典的控制任务,Bsuite任务和Minatar(类似Atari的)环境。


链接:https://github.com/RobertTLange/gymnax


研究人员选择Gymnax作为测试和评估代码的首选库,在这篇文章中的示例用的也是Gymnax,库中还包括许多其他非常简洁的功能,并且非常容易使用。


Brax是使用JAX运行类似Mujoco的连续控制环境的方法,该库包含许多强化学习环境,可以对标类似的连续控制环境,如HalfCheetah和Humanoid,并且也是可微的!


链接:https://github.com/google/brax


Jumanji包含许多令人特别炫酷、简单和行业驱动的环境,库中的许多环境都直接来自于行业设置,确保这里提供的环境是实用的并且与现实世界相关,具体问题包括组合问题,如著名的旅行推销员问题或3D装箱。


链接:https://github.com/instadeepai/jumanji


Pgx有许多流行的桌面游戏和其他环境,包括Connect 4、围棋、扑克!


链接:https://github.com/sotetsuk/pgx


在Gymnax的测速基线报告显示,如果用numpy使用CartPole-v1在10个环境并行运行的情况下,需要46秒才能达到100万帧;在A100上使用Gymnax,在2k 环境下并行运行只需要0.05秒,加速达到1000倍


这个结论也适用于比Cartpole-v1更复杂的环境,例如Minatar-Breakout需要50秒才能在 CPU 上达到100万帧,而在 Gymnax 只需要0.2秒。


这些实验结果显示了多个数量级的改进,使学术研究人员能够在有限的硬件上高效地运行超过数万亿帧的实验。



在JAX中端到端地进行所有操作有几个优势:


  • 在加速器上的矢量化环境运行速度更快。
  • 通过将计算完全保留在GPU上,可以避免在CPU和GPU之间来回复制数据的开销,通常也是性能的一个关键瓶颈。
  • 通过JIT编译实现,可以避免Python的开销,有时会阻塞发送命令之间的GPU 计算。
  • JIT 编译通过运算符融合(operator fusion)可以获得显著的加速效果,即优化了GPU上的内存使用。
  • 多线程的并行运行环境很难调试,并且会导致复杂的基础设施。


为了证明这些优势,作者在纯JAX环境中复制了CleanRL的PyTorch PPO基线实现,使用了相同数量的并行环境和相同的超参数设置,并且没有利用海量环境矢量化的优势。


在Cartpole-v1和MinAtar-Breakout中运行5次,训练过程如下。


Cartpole-v1和 MinAtar-Breakout上的CleanRLvs JAX PPO,给定相同的超参数和帧数,得到了几乎相同的结果。


将x轴从帧替换为wall-clock time(某个线程上实际执行的时间)后,在没有任何额外并行环境的情况下,速度提升了10倍以上。


Cartpole-v1和 MinAtar-Breakout上的CleanRL vs JAX PPO,得到了相同的结果,但是快了10倍以上!


并行运行多个智能体


虽然可以从上述技巧中得到相当不错的加速效果,但与标题中的4000倍加速仍然相去甚远。


通过向量化整个强化学习训练循环以及之前提到JAX中的vmap,可以很容易地并行训练多个智能体。






rng = jax.random.PRNGKey(42)rngs = jax.random.split(rng, 256)train_vjit = jax.jit(jax.vmap(make_train(config)))outs = train_vjit(rngs)


此外,还可以使用JAX提供的pmap函数在多个 GPU 上运行,在此之前,这种跨设备的并行化和向量化,尤其是在设备内部的并行化和向量化,是一个非常令人头疼的问题。


Cartpole-v1和 MinAtar-Breakout 上的CleanRL vs Jax PPO,可以将智能体训练本身并行化。在 Cartpole-v1上,只需要用训练一个CleanRL智能体的一半时间来训练2048个智能体。


如果正在开发一个新的强化学习算法,那么就可以在单个GPU上同时对具有统计学意义的大量种子进行快速训练。


除此之外,还可以同时训练成千上万的独立智能体,在作者提供的代码中,还展示了如何使用进行快速超参数搜索,也可以将其用于进化元学习!



相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
8月前
|
机器学习/深度学习 存储 PyTorch
【AMP实操】解放你的GPU运行内存!在pytorch中使用自动混合精度训练
【AMP实操】解放你的GPU运行内存!在pytorch中使用自动混合精度训练
284 0
|
8月前
|
机器学习/深度学习 异构计算 Python
Bert-vits2最终版Bert-vits2-2.3云端训练和推理(Colab免费GPU算力平台)
对于深度学习初学者来说,JupyterNoteBook的脚本运行形式显然更加友好,依托Python语言的跨平台特性,JupyterNoteBook既可以在本地线下环境运行,也可以在线上服务器上运行。GoogleColab作为免费GPU算力平台的执牛耳者,更是让JupyterNoteBook的脚本运行形式如虎添翼。 本次我们利用Bert-vits2的最终版Bert-vits2-v2.3和JupyterNoteBook的脚本来复刻生化危机6的人气角色艾达王(ada wong)。
Bert-vits2最终版Bert-vits2-2.3云端训练和推理(Colab免费GPU算力平台)
|
3月前
|
并行计算 Shell TensorFlow
Tensorflow-GPU训练MTCNN出现错误-Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
在使用TensorFlow-GPU训练MTCNN时,如果遇到“Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED”错误,通常是由于TensorFlow、CUDA和cuDNN版本不兼容或显存分配问题导致的,可以通过安装匹配的版本或在代码中设置动态显存分配来解决。
62 1
Tensorflow-GPU训练MTCNN出现错误-Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
|
3月前
|
人工智能 语音技术 UED
仅用4块GPU、不到3天训练出开源版GPT-4o,这是国内团队最新研究
【10月更文挑战第19天】中国科学院计算技术研究所提出了一种名为LLaMA-Omni的新型模型架构,实现与大型语言模型(LLMs)的低延迟、高质量语音交互。该模型集成了预训练的语音编码器、语音适配器、LLM和流式语音解码器,能够在不进行语音转录的情况下直接生成文本和语音响应,显著提升了用户体验。实验结果显示,LLaMA-Omni的响应延迟低至226ms,具有创新性和实用性。
82 1
|
5月前
|
机器学习/深度学习 并行计算 PyTorch
GPU 加速与 PyTorch:最大化硬件性能提升训练速度
【8月更文第29天】GPU(图形处理单元)因其并行计算能力而成为深度学习领域的重要组成部分。本文将介绍如何利用PyTorch来高效地利用GPU进行深度学习模型的训练,从而最大化训练速度。我们将讨论如何配置环境、选择合适的硬件、编写高效的代码以及利用高级特性来提高性能。
908 1
|
5月前
|
并行计算 算法 调度
自研分布式训练框架EPL问题之提高GPU利用率如何解决
自研分布式训练框架EPL问题之提高GPU利用率如何解决
|
7月前
|
机器学习/深度学习 自然语言处理 异构计算
单GPU训练一天,Transformer在100位数字加法上就达能到99%准确率
【6月更文挑战第11天】Transformer模型在算术任务上取得重大突破,通过引入Abacus Embeddings,一天内在100位数字加法上达到99%准确率。该嵌入方法帮助模型跟踪数字位置,提升处理长序列的能力。实验还显示,Abacus Embeddings可与其他嵌入方法结合,提升乘法任务性能。然而,模型在更长序列的扩展性和其他类型任务的效果仍有待探究,具体训练技术的影响也需要进一步研究。论文链接:https://arxiv.org/pdf/2405.17399
83 1
|
8月前
|
机器学习/深度学习 弹性计算 自然语言处理
【阿里云弹性计算】深度学习训练平台搭建:阿里云 ECS 与 GPU 实例的高效利用
【5月更文挑战第28天】阿里云ECS结合GPU实例为深度学习提供高效解决方案。通过弹性计算服务满足大量计算需求,GPU加速训练。用户可按需选择实例规格,配置深度学习框架,实现快速搭建训练平台。示例代码展示了在GPU实例上使用TensorFlow进行训练。优化包括合理分配GPU资源和使用混合精度技术,应用涵盖图像识别和自然语言处理。注意成本控制及数据安全,借助阿里云推动深度学习发展。
282 2
|
8月前
|
机器学习/深度学习 人工智能 算法
为什么大模型训练需要GPU,以及适合训练大模型的GPU介绍
为什么大模型训练需要GPU,以及适合训练大模型的GPU介绍
327 1
|
8月前
|
机器学习/深度学习 并行计算 PyTorch
【多GPU炼丹-绝对有用】PyTorch多GPU并行训练:深度解析与实战代码指南
本文介绍了PyTorch中利用多GPU进行深度学习的三种策略:数据并行、模型并行和两者结合。通过`DataParallel`实现数据拆分、模型不拆分,将数据批次在不同GPU上处理;数据不拆分、模型拆分则将模型组件分配到不同GPU,适用于复杂模型;数据和模型都拆分,适合大型模型,使用`DistributedDataParallel`结合`torch.distributed`进行分布式训练。代码示例展示了如何在实践中应用这些策略。
2087 2
【多GPU炼丹-绝对有用】PyTorch多GPU并行训练:深度解析与实战代码指南