谷歌开源计算框架JAX:比Numpy快30倍,还可在TPU上运行!

简介: 大家有了解过JAX吗?JAX是一种可在CPU、GPU和TPU上运行的“Numpy”,专门针对机器学习研究,并提供高性能自微分计算能力,速度要比纯用Numpy快几十倍!

 微信图片_20220112140819.png


相信大家对numpy, Tensorflow, Pytorch已经极其熟悉,不过,你知道JAX吗? JAX发布之后,有网友进行了测试,发现,使用JAX,Numpy运算可以快三十多倍! 


下面是使用Numpy的运行情况:


import numpy as np  # 使用标准numpy,运算将在CPU上执行。
x = np.random.random([5000, 5000]).astype(np.float32)
%timeit np.matmul(x, x)


运行结果:


1 loop, best of 3: 3.9 s per loop 而下面是使用JAX的Numpy的情况:
import jax.numpy as np # 使用"JAX版"的numpy from jax import random # 注意JAX下随机数API有所不同 x = random.uniform(random.PRNGKey(0), [5000, 5000]) %timeit np.matmul(x, x)


运行情况:


1 loop, best of 3: 109 ms per loop


我们可以发现,使用原始numpy,运行时间大概为3.9s,而使用JAX的numpy,运行时间仅仅只有0.109s,速度上直接提升了三十多倍! 


是不是很神奇? 那JAX到底是什么?


小编我就不卖关子了:


JAX是谷歌开源的、可以在CPU、GPU和TPU上运行的numpy,是针对机器学习研究的高性能自微分计算框架。


简单来说,就是GPU和TPU加速、支持自动微分(autodiff)的numpy。


 微信图片_20220112140824.png


快速入门链接:https://jax.readthedocs.io/en/latest/notebooks/quickstart.html


 我们都知道,numpy是Python下的基础数值运算库,应用非常广泛,如果要用Python进行科学计算或者机器学习,没人能够离得了它。


 但是,numpy并不支持GPU或者其他硬件加速器,也缺少对backpropagation的内置支持,此外,Python自身也有速度限制,


因此,在生产环境下使用numpy训练或者部署深度学习模型的人很少。 不过numpy也有它独特的魅力:底层、灵活、调试方便、API稳定且为大家所熟悉,从而深受研究者的喜爱。


 而JAX的主要出发点就是将numpy的以上优势与硬件加速结合,与依赖于预编译核和快速C++代码的Pytorch相比,JAX可以让我们能够在高级接口使用自己最喜欢的加速器进行编写。 在最高层,JAX结合了XLA&Autograd,来加速用户开发的基于线性代数的项目。 


Github项目地址:https://github.com/google/jax


微信图片_20220112140826.png


此外,入门JAX的过程非常自然简单——许多人每天都在处理numpy的语法和规定,而JAX则大大减少了用户的这些烦恼。 


目前,JAX支持在Linux (Ubuntu 16.04或更高版本)和macOS(10.12或更高版本)平台上安装或构建,Windows用户可以通过Windows的Linux子系统在CPU和GPU上使用JAX。 


参考链接:


https://roberttlange.github.io/posts/2021/02/cma-es-jax/

https://roberttlange.github.io/posts/2020/03/blog-post-10/

https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

https://www.zhihu.com/question/306496943/answer/1041519580

相关实践学习
在云上部署ChatGLM2-6B大模型(GPU版)
ChatGLM2-6B是由智谱AI及清华KEG实验室于2023年6月发布的中英双语对话开源大模型。通过本实验,可以学习如何配置AIGC开发环境,如何部署ChatGLM2-6B大模型。
相关文章
|
机器学习/深度学习 并行计算 大数据
【Python篇】NumPy完整指南(上篇):掌握数组、矩阵与高效计算的核心技巧2
【Python篇】NumPy完整指南(上篇):掌握数组、矩阵与高效计算的核心技巧
396 10
|
索引 Python
【Python篇】NumPy完整指南(上篇):掌握数组、矩阵与高效计算的核心技巧1
【Python篇】NumPy完整指南(上篇):掌握数组、矩阵与高效计算的核心技巧
432 4
|
机器学习/深度学习 并行计算 调度
CuPy:将 NumPy 数组调度到 GPU 上运行
CuPy:将 NumPy 数组调度到 GPU 上运行
531 1
|
PyTorch 算法框架/工具 Python
Pytorch学习笔记(十):Torch对张量的计算、Numpy对数组的计算、它们之间的转换
这篇文章是关于PyTorch张量和Numpy数组的计算方法及其相互转换的详细学习笔记。
322 0
|
分布式计算 并行计算 大数据
NumPy 并行计算与分布式部署
【8月更文第30天】随着数据量的不断增长,传统的单机计算模型已经难以满足对大规模数据集处理的需求。并行和分布式计算成为了处理这些大数据集的关键技术。虽然 NumPy 本身并不直接支持并行计算,但可以通过结合其他库如 Numba 和 Dask 来实现高效的并行和分布式计算。
215 1
|
SQL 并行计算 API
Dask是一个用于并行计算的Python库,它提供了类似于Pandas和NumPy的API,但能够在大型数据集上进行并行计算。
Dask是一个用于并行计算的Python库,它提供了类似于Pandas和NumPy的API,但能够在大型数据集上进行并行计算。
|
机器学习/深度学习 C语言 索引
数组计算模块NumPy(一)
NumPy是Python科学计算的核心库,提供高性能的数组和矩阵操作,支持大量数学函数。它包括一维、二维到多维数组,并通过C实现,优化了计算速度。
数组计算模块NumPy(一)
|
存储 数据挖掘 API
【NumPy基础】- Numpy数组和矢量计算
【NumPy基础】- Numpy数组和矢量计算
150 4
|
索引 Python
数组计算模块NumPy(二)
NumPy教程概要:介绍数组切片、二维数组索引、重塑、转置和数组操作。讨论了切片语法`[start:stop:step]`,二维数组的索引方式,以及reshape方法改变数组形状。涉及转置通过`.T`属性或`transpose()`函数实现,数组增加使用`hstack()`和`vstack()`,删除用`delete()`。还提到了矩阵运算,包括加减乘除,并展示了`numpy.dot()`和`@`运算符的使用。最后提到了排序函数`sort()`、`argsort()`和`lexsort()`,以及NumPy的统计分析函数如均值、标准差等。
|
数据挖掘 数据处理 Python
《Numpy 简易速速上手小册》第5章:Numpy高效计算与广播(2024 最新版)
《Numpy 简易速速上手小册》第5章:Numpy高效计算与广播(2024 最新版)
145 0
下一篇
oss云网关配置