谷歌开源计算框架JAX:比Numpy快30倍,还可在TPU上运行!

简介: 大家有了解过JAX吗?JAX是一种可在CPU、GPU和TPU上运行的“Numpy”,专门针对机器学习研究,并提供高性能自微分计算能力,速度要比纯用Numpy快几十倍!

 微信图片_20220112140819.png


相信大家对numpy, Tensorflow, Pytorch已经极其熟悉,不过,你知道JAX吗? JAX发布之后,有网友进行了测试,发现,使用JAX,Numpy运算可以快三十多倍! 


下面是使用Numpy的运行情况:


import numpy as np  # 使用标准numpy,运算将在CPU上执行。
x = np.random.random([5000, 5000]).astype(np.float32)
%timeit np.matmul(x, x)


运行结果:


1 loop, best of 3: 3.9 s per loop 而下面是使用JAX的Numpy的情况:
import jax.numpy as np # 使用"JAX版"的numpy from jax import random # 注意JAX下随机数API有所不同 x = random.uniform(random.PRNGKey(0), [5000, 5000]) %timeit np.matmul(x, x)


运行情况:


1 loop, best of 3: 109 ms per loop


我们可以发现,使用原始numpy,运行时间大概为3.9s,而使用JAX的numpy,运行时间仅仅只有0.109s,速度上直接提升了三十多倍! 


是不是很神奇? 那JAX到底是什么?


小编我就不卖关子了:


JAX是谷歌开源的、可以在CPU、GPU和TPU上运行的numpy,是针对机器学习研究的高性能自微分计算框架。


简单来说,就是GPU和TPU加速、支持自动微分(autodiff)的numpy。


 微信图片_20220112140824.png


快速入门链接:https://jax.readthedocs.io/en/latest/notebooks/quickstart.html


 我们都知道,numpy是Python下的基础数值运算库,应用非常广泛,如果要用Python进行科学计算或者机器学习,没人能够离得了它。


 但是,numpy并不支持GPU或者其他硬件加速器,也缺少对backpropagation的内置支持,此外,Python自身也有速度限制,


因此,在生产环境下使用numpy训练或者部署深度学习模型的人很少。 不过numpy也有它独特的魅力:底层、灵活、调试方便、API稳定且为大家所熟悉,从而深受研究者的喜爱。


 而JAX的主要出发点就是将numpy的以上优势与硬件加速结合,与依赖于预编译核和快速C++代码的Pytorch相比,JAX可以让我们能够在高级接口使用自己最喜欢的加速器进行编写。 在最高层,JAX结合了XLA&Autograd,来加速用户开发的基于线性代数的项目。 


Github项目地址:https://github.com/google/jax


微信图片_20220112140826.png


此外,入门JAX的过程非常自然简单——许多人每天都在处理numpy的语法和规定,而JAX则大大减少了用户的这些烦恼。 


目前,JAX支持在Linux (Ubuntu 16.04或更高版本)和macOS(10.12或更高版本)平台上安装或构建,Windows用户可以通过Windows的Linux子系统在CPU和GPU上使用JAX。 


参考链接:


https://roberttlange.github.io/posts/2021/02/cma-es-jax/

https://roberttlange.github.io/posts/2020/03/blog-post-10/

https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

https://www.zhihu.com/question/306496943/answer/1041519580

相关实践学习
基于阿里云DeepGPU实例,用AI画唯美国风少女
本实验基于阿里云DeepGPU实例,使用aiacctorch加速stable-diffusion-webui,用AI画唯美国风少女,可提升性能至高至原性能的2.6倍。
相关文章
|
24天前
|
数据挖掘 数据处理 Python
《Numpy 简易速速上手小册》第5章:Numpy高效计算与广播(2024 最新版)
《Numpy 简易速速上手小册》第5章:Numpy高效计算与广播(2024 最新版)
31 0
|
24天前
|
数据采集 机器学习/深度学习 算法
《Numpy 简易速速上手小册》第4章:Numpy 数学和统计计算(2024 最新版)
《Numpy 简易速速上手小册》第4章:Numpy 数学和统计计算(2024 最新版)
29 0
|
4月前
|
存储 大数据 索引
【Python】NumPy数组和矢量计算
【1月更文挑战第26天】【Python】NumPy数组和矢量计算
|
5月前
|
机器学习/深度学习 PyTorch TensorFlow
JAX: 快如 PyTorch,简单如 NumPy - 深度学习与数据科学
JAX: 快如 PyTorch,简单如 NumPy - 深度学习与数据科学
52 0
|
Python
python计算的效率问题-pandas、numpy结合代替遍历pandas数据
python计算的效率问题-pandas、numpy结合代替遍历pandas数据
83 0
python计算的效率问题-pandas、numpy结合代替遍历pandas数据
python_numpy_计算对数收益率和还原问题
python_numpy_计算对数收益率和还原问题
105 0
|
Python
numpy向量计算
numpy向量计算
75 0
|
机器学习/深度学习 存储 Serverless
NumPy 与 Python 内置列表计算标准差的区别
NumPy,是 Numerical Python 的简称,用于高性能科学计算和数据分析的基础包,像数学科学工具(pandas)和框架(Scikit-learn)中都使用到了 NumPy 这个包。
|
数据挖掘 Python
Python | Numpy:详解计算矩阵的均值和标准差
对于 CRITIC 权重法而言,在标准差一定时,指标间冲突性越小,权重也越小;冲突性越大,权重也越大;另外,当两个指标间的正相关程度越大时,(相关系数越接近1),冲突性越小,这表明这两个指标在评价方案的优劣上反映的信息有较大的相似性。
328 0
Python | Numpy:详解计算矩阵的均值和标准差
|
算法 Python 计算机视觉
用numpy计算成交量加权平均价格(VWAP),并实现读写文件
VWAP(Volume-Weighted Average Price,成交量加权平均价格)是一个非常重要的经济学量,它代表着金融资产的“平均”价格。某个价格的成交量越高,该价格所占的权重就越大。VWAP就是以成交量为权重计算出来的加权平均值,常用于算法交易。
2583 0