近年来,谷歌于 2018 年推出的 JAX 迎来了迅猛发展,很多研究者对其寄予厚望,希望它可以取代 TensorFlow 等众多深度学习框架。但 JAX 是否真的适合所有人使用呢?这篇文章对 JAX 的方方面面展开了深入探讨,希望可以给研究者选择深度学习框架时提供有益的参考。
自 2018 年底推出以来,JAX 的受欢迎程度一直在稳步提升。2020 年,DeepMind 宣布使用 JAX 来加速其研究。越来越多来自谷歌大脑(Google Brain)和其他机构的项目也都在使用 JAX。
目前,在 JAX 的 GitHub 项目主页,Star 量已经达到了 16.3k。
项目地址:https://github.com/google/jax
JAX 是一个非常有前途的项目,并且用户一直在稳步增长。JAX 已经在深度学习、机器人 / 控制系统、贝叶斯方法和科学模拟等诸多领域得到了广泛应用。
如此,是否意味着 JAX 也将成为下一个大型深度学习框架?近日,发表在 AssemblyAI 博客上的文章《Why You Should (or Shouldn't) Be Using JAX in 2022》中,作者 Ryan O'Connor 为我们深入解读了 JAX 的概念、使用 JAX 的理由以及是否应该使用 JAX 等。
JAX 简介
JAX 不是一个深度学习框架或库,其设计初衷也不是成为一个深度学习框架或库。简而言之,JAX 是一个包含可组合函数转换的数值计算库。正如我们所看到的,深度学习只是 JAX 功能的一小部分:
JAX 的定位科学计算(Scientific Computing)和函数转换(Function Transformations)的交叉融合,具有除训练深度学习模型以外的一系列能力,包括如下:
- 即时编译(Just-in-Time Compilation)
- 自动并行化(Automatic Parallelization)
- 自动向量化(Automatic Vectorization)
- 自动微分(Automatic Differentiation)
使用 JAX 的原因有哪些?
简而言之,是速度。这是 JAX 与任何用例相关的一种通用能力。让我们使用 NumPy 和 JAX 对矩阵的前三个幂求和(按元素)。
首先是 NumPy 实现。我们发现,该计算大约需要 851 毫秒。
然后使用 JAX 实现该计算:
JAX 仅在 5.54 毫秒内执行完成该计算,速度是 NumPy 的 150 倍以上。
JAX 的速度比 NumPy 快了 N 个数量级。需要注意,JAX 使用的是 TPU,NumPy 使用了 CPU,以此强调 JAX 的速度上限远高于 NumPy。
作者列出了以下六条可能想要使用 JAX 的理由:
- NumPy 加速器。NumPy 是使用 Python 进行科学计算的基础包之一,但它仅与 CPU 兼容。JAX 提供了 NumPy 的实现(具有几乎相同的 API),可以非常轻松地在 GPU 和 TPU 上运行。对于许多用户而言,仅此一项功能就足以证明使用 JAX 的合理性;
- XLA。XLA(Accelerated Linear Algebra)是专为线性代数设计的全程序优化编译器。JAX 建立在 XLA 之上,显著提高了计算速度上限;
- JIT。JAX 允许用户使用 XLA 将自己的函数转换为即时编译(JIT)版本。这意味着可以通过在计算函数中添加一个简单的函数装饰器(decorator)来将计算速度提高几个数量级;
- Auto-differentiation。JAX 将 Autograd(自动区分原生 Python 代码和 NumPy 代码)和 XLA 结合在一起,它的自动微分能力在科学计算的许多领域都至关重要。JAX 提供了几个强大的自动微分工具;
- 深度学习。虽然 JAX 本身不是深度学习框架,但它的确为深度学习提供了一个很好的基础。很多构建在 JAX 之上的库旨在提供深度学习功能,包括 Flax、Haiku 和 Elegy。甚至在最近的一些 PyTorch 与 TensorFlow 文章中强调了 JAX 作为一个值得关注的「框架」,并推荐其用于基于 TPU 的深度学习研究。JAX 对 Hessians 的高效计算也与深度学习相关,因为它们使高阶优化技术更加可行;
- 通用可微分编程范式(General Differentiable Programming Paradigm )。虽然我们可以使用 JAX 来构建和训练深度学习模型,但它也为通用可微编程提供了一个框架。这意味着 JAX 可以通过使用基于模型的机器学习方法来解决问题,从而可以利用数十年研究建立起的给定领域的先验知识。