【新智元导读】数值计算库JAX自发布以来就备受关注,支持者认为它是「真香」,性能快;但反对者也表示JAX还太年轻,漏洞多,为此最近还引发了一场Reddit大讨论。
自从JAX在2018年底发布以来,受众逐渐增加。随着DeepMind在2020年宣布开始使用JAX来加速研究,越来越多的代码,如来自Google Brain等公司的项目都开始使用JAX。似乎JAX已经是下一个巨头深度学习框架了。最近一位做开发人员教育的从业者Ryan O'Connor发布了一篇文章,详细介绍了JAX是否将取代TensorFlow/PyTorch,并阐述了哪些人群适合在2022年开始使用JAX。
什么是JAX?
JAX并非是一个深度学习的框架或者库,它的设计目标也并非是作为一个新的深度学习框架。简单来说,JAX是一个包含可组合函数变换的数值计算库,只不过深度学习恰好是JAX能做的一项工作。JAX处于函数变换(function transformations)和科学计算的交界处,所以也有能力训练神经网络模型,但不止于训练。JAX最初由谷歌大脑团队的 Matt Johnson、Roy Frostig、Dougal Maclaurin 和 Chris Leary 等人发起,借助 Autograd 的更新版本,并且结合了 XLA,可对 Python 程序与 NumPy 运算执行自动微分,支持循环、分支、递归、闭包函数求导,也可以求三阶导数;依赖于 XLA,JAX 可以在 GPU 和 TPU 上编译和运行 NumPy 程序;通过 grad,可以支持自动模式反向传播和正向传播,且二者可以任意组合成任何顺序。目前JAX在Github上已经斩获了超1.6万颗star,相比之下tensorflow的stars为16万,pyTorch的stars为5.4万,所以想要在深度学习领域超越两位老大哥,路还有很长。
为什么应该了解JAX?
如果一个用户选择了JAX,那基本上只有一个原因:速度。例如同一段函数功能,可以看到numpy的实现需要大概851毫秒。而如果换成JAX,结果缩短到5.54ms,实现了超过numpy 150倍的性能提升!如果画成直方图,性能优势就显得更明显了。而JAX计算更快的原因是使用了TPU,而NumPy只使用CPU。虽然实际使用中并不是「用上JAX,你的程序就会加速150倍」那么简单,但仍然有很多理由来使用它。JAX为科学计算提供了一个通用的基础,它对不同领域的人有不同的用途。从根本上说,如果你在任何与科学计算有关的领域,你都应该了解JAX。作者列出了6个应该使用JAX原因:1. 加速NumPy。NumPy是用Python进行科学计算的基本软件包之一,但它只与CPU兼容。JAX提供了一个NumPy的实现(具有近乎相同的API),可以非常容易地在GPU和TPU上工作。对于许多用户来说,仅仅这一点就足以证明使用JAX的合理性。2. XLA,即加速线性代数(Accelerated Linear Algebra),是一个全程序优化编译器,专门为线性代数设计。JAX是建立在XLA之上的,大大提升了计算速度的上限。3. JIT。JAX允许用户使用XLA将函数转化为JIT(just in time)编译的版本。这意味用户可以通过给计算函数添加一个简单的函数装饰器来提高计算速度,可能是几个数量级的性能提升。4. 自动求导。JAX文档将JAX称为Autograd和XLA的结合体。自动求导的能力在科学计算的许多领域都是至关重要的,而JAX提供了几个强大的自动求导工具。5. 深度学习。虽然JAX本身不是一个深度学习框架,但它肯定为深度学习提供了一个更充分的基础。现在有许多建立在JAX之上的深度学习库,例如Flax、Haiku和Elegy。甚至有研究人员在PyTorch vs TensorFlow文章中强调JAX也是一个值得关注的「框架」,推荐其用于基于TPU的深度学习研究。JAX对Hessians的高效计算也与深度学习有关,因为它们使高阶优化技术更进一步。6. 通用可微分编程范式。虽然可以使用JAX来构建和训练深度学习模型,但它也为通用可微分编程提供了一个框架。这意味着JAX可以通过使用基于模型的机器学习方法来解决实际问题。
2022年,该学JAX吗?
就和所有令人纠结的问题一样,这个问题的答案依然是:It dependes.是否决定把项目迁移到JAX,取决于你的具体情况和目标。如果你对用于通用科学计算的JAX感兴趣,你应该问自己的第一个问题是你是否只是想加速NumPy。如果你的答案是「是」,那么你昨天就应该使用JAX了。如果你不只是在计算数字,而是在参与动态计算建模,那么你是否应该使用JAX将取决于你的使用情况。如果你的大部分工作是在Python中使用大量的自定义代码,那么开始学习JAX以提高你的工作流程是值得的。或者如果你的大部分工作不是用Python,但你想建立某种基于模型/神经网络的混合系统,那么使用JAX可能是值得的。如果你的大部分工作不是用Python,或者你正在使用一些专门的软件进行研究(热力学、半导体等等),那么JAX可能并不是适合你的工具,除非你想从这些程序中导出数据进行某种自定义计算处理。如果你的兴趣领域更接近于物理/数学,并且包含了计算方法(动力系统、微分几何、统计物理),或者你的大部分工作是在例如Mathematica中进行的,那么坚持使用你正在使用的东西可能是值得的,特别是如果你有一个大型的自定义代码库。对于深度学习从业者来说,如果你刚开始接触深度学习,并考虑使用JAX,并且你对学习深度学习感兴趣,为了自己的教育,那么建议你使用JAX或PyTorch。如果你想自上而下地学习深度学习和/或有一些Python/软件经验,那么建议你从PyTorch开始。如果你想自下而上地学习深度学习和/或来自数学背景,你可能会觉得JAX很直观,应该试一试。在这种情况下,在进行任何大项目之前,请确保你了解如何使用JAX。如果你对深度学习感兴趣,并有可能为此而改变自己职业,那么你使用PyTorch或TensorFlow是更好的选择。虽然最好熟悉这两个框架,但你应该知道,TensorFlow被认为是「行业」框架,而PyTorch则更加「学术」。如果你是一个完全的初学者,没有数学或软件背景,但想学习深度学习,那么你就不会想使用JAX,从Keras开始是更好的选择。
不该使用JAX的情况
虽然JAX有可能极大地提高你的程序的性能,但也有几种情况下,是不适合使用JAX的。1. JAX仍然被官方认为是一个实验性框架,而不是一个完全成熟的Google产品,所以如果你正在考虑转移到JAX,需要慎重考虑。2. 在使用JAX时,调试的时间成本会更高,并且有很多bug仍然未被发现。对于那些没有牢固掌握函数式编程的人来说,使用JAX可能不值得。在开始将JAX用于正式的项目之前,请确保了解使用JAX的常见陷阱。3. JAX没有针对CPU计算进行优化。鉴于JAX是以「加速优先」的方式开发的,因此每个操作的调度并没有完全优化。正因为如此,在某些情况下,NumPy实际上可能比JAX更快,特别是对于小程序来说。4. JAX与Windows不兼容。目前在Windows上没有对JAX的支持。
Reddit网友讨论
文章的作者同时在Reddit上和网友进行了讨论,获得了近200个赞。网友lsaldyt表示他一直致力于用jax做序列模型(LSTM、NTM等),然后发现XLA的编译对于非常复杂的模型来说有点棘手。但他喜欢jax,一有机会就会向朋友宣传,但它绝对是一把双刃剑。他认为在几年内,JAX框架会变得更平滑,并且绝对会比其他框架更好。另外,很多基线是在pytorch中实现的,并且同时运行pytorch和jax相对容易。也有网友AlmightySnoo表示坚持使用PyTorch,只是因为它有一个更大的生态系统,一个先进的JIT机制(现在也适用于CPU),以及最重要的libtorch。你可以完全用C++来做训练和推理,而不需要用JAX/Python/XLA的组合来捣乱,并通过Tensorflow在C++中使用XLA。谷歌以后就会明白,训练过程中的低延迟也是非常重要的,在许多领域(尤其是量化金融,因为他们采用了libtorch,因为在这些用例中,你必须在每次使用时重新训练,不能简单地委托给Python。从长远来看,Facebook的这个设计决定显然是成功的。所以你认为,JAX到底是明日之星还是昙花一现?
参考资料:
https://www.assemblyai.com/blog/why-you-should-or-shouldnt-be-using-jax-in-2022/
https://www.reddit.com/r/MachineLearning/comments/st8b11/d_should_we_be_using_jax_in_2022/