Jax 和 Jaxlib 版本控制
原文:
jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html
为什么 jax
和 jaxlib
是独立的包?
我们将 JAX 发布为两个独立的 Python 轮子,即纯 Python 轮子 jax
和主要由 C++ 组成的轮子 jaxlib
,后者包含库,例如:
我们发布 jax
作为两个独立的 Python 轮子,即纯 Python 轮子 jax
和主要由 C++ 组成的轮子 jaxlib
,后者包含如下库:
此外,构建 jaxlib
不是廉价的,但我们希望能够在没有大量 CPU 的环境中迭代并运行 JAX 测试,例如在 Github Actions 或笔记本电脑上。我们的许多 CI 构建都简单地使用预构建的 jaxlib
,而不是在每个 PR 上重新构建 JAX 的 C++ 组件。
如我们将看到的,将 jax
和 jaxlib
分开发布也有一定成本,因为需要确保 jaxlib
的变更保持向后兼容的 API。然而,我们认为总体上,使得 Python 的变更变得简单是可取的,即使这会稍微增加 C++ 变更的难度。
jax
和 jaxlib
的版本如何确定?
概要:jax
和 jaxlib
在 JAX 源代码树中共享相同的版本号,但作为单独的 Python 包发布。安装时,jax
包版本必须大于或等于 jaxlib
的版本,并且 jaxlib
的版本必须大于或等于 jax
指定的最小 jaxlib
版本。
jax
和 jaxlib
发布版本号均为 x.y.z
,其中 x
是主版本号,y
是次版本号,z
是可选的补丁版本号。版本号必须遵循PEP 440。版本号比较是对整数元组的词典排序比较。
每个 jax
发布版本都有一个关联的最小 jaxlib
版本 mx.my.mz
。对于 jax
版本 x.y.z
,其最小 jaxlib
版本必须不大于 x.y.z
。
对于 jax
版本 x.y.z
和 jaxlib
版本 lx.ly.lz
兼容性要求如下:
jaxlib
版本(lx.ly.lz
)必须大于或等于最小的jaxlib
版本(mx.my.mz
)。jax
版本(x.y.z
)必须大于或等于jaxlib
版本(lx.ly.lz
)。
这些约束意味着发布需遵循以下规则:
- 可以随时单独发布
jax
而不更新jaxlib
。 - 如果发布新版
jaxlib
,必须同时发布一个jax
版本。
当前 jax
在导入时检查这些版本约束,而不是作为 Python 包版本约束来表达。 jax
在运行时检查 jaxlib
版本,而不是使用 pip
包版本约束,因为我们为各种硬件和软件版本(如 GPU、TPU 等)提供单独的 jaxlib
轮子。由于我们不知道哪种选择对任何给定用户来说是正确的,我们不希望 pip
自动为我们安装 jaxlib
包。
将来,我们希望将 jaxlib
的硬件特定部分分离成单独的插件,届时最低版本可以表达为 Python 包依赖性。目前,我们确实提供特定平台的额外要求,以安装兼容的 jaxlib
版本,例如 jax[cuda]
。
如何安全地对 jaxlib
的 API 进行更改?
jax
可能随时停止与旧版本的jaxlib
兼容,只要将最低jaxlib
版本升级到兼容版本即可。但请注意,即使是对于尚未发布的jax
版本,最低jaxlib
版本也必须是一个已发布的版本!这允许我们在持续集成构建中使用已发布的jaxlib
轮子,并允许 Python 开发者在不需要构建jaxlib
的情况下在 HEAD 上工作。
例如,要在jax
Python 代码中移除旧的向后兼容路径,只需提高最低jaxlib
版本然后删除兼容路径即可。jaxlib
可能会停止与低于其自身发布版本号的旧jax
发行版的兼容性。jax
强制执行的版本约束将禁止使用不兼容的jaxlib
。
例如,要使jaxlib
放弃一个旧的jax
版本使用的 Python 绑定 API,必须增加jaxlib
的次要或主要版本号。- 如果可能,应以向后兼容的方式对
jaxlib
进行更改。
通常,jaxlib
可以自由更改其 API,只要遵循jax
必须与至少两个jaxlib
版本兼容的规则。这意味着jax
必须始终与至少两个jaxlib
版本兼容,即最后一个发布版本和最新版本(实际上是下一个发布版本)。如果保持兼容性,这将更容易实现,尽管可以通过jax
的版本测试进行不兼容的更改;请参见下文。
例如,通常可以安全地向jaxlib
添加新功能,但是如果当前的jax
仍在使用它,删除现有功能或更改其签名则是不安全的。对jax
的更改必须在所有大于最低版本的jaxlib
发行版上运行或逐渐退化。
请注意,此处的兼容性规则仅适用于发布的jax
和jaxlib
版本。它们不适用于未发布的版本;也就是说,如果从未发布或没有发布的jax
版本使用该 API,则可以引入并删除jaxlib
中的 API。
jaxlib
的源代码布局是怎样的?
jaxlib
被分为两个主要的存储库,即jaxlib/
主 JAX 存储库的子目录和XLA 源代码树,位于 XLA 存储库内部。XLA 内部的 JAX 特定部分主要位于xla/python
子目录。
JAX 的 C++ 组件,如 Python 绑定和运行时组件,位于 XLA 树内部的原因部分是历史原因,部分是技术原因。
历史原因是最初xla/python
绑定被构想为通用 Python 绑定,可能与其他框架共享。实际上,这种情况越来越少见,xla/python
包含了许多特定于 JAX 的部分,并且可能会包含更多。因此,最好将xla/python
简单地视为 JAX 的一部分。
技术原因在于 XLA C++ API 不稳定。通过将 XLA:Python 绑定保留在 XLA 树中,可以将它们的 C++ 实现与 XLA 的 C++ API 进行原子更新。在 Python API 层面上,维护 Python API 的向后和向前兼容性要比维护 C++ API 更容易,因此xla/python
公开了 Python API 并负责在 Python 层面上维护向后兼容性。
jaxlib
使用 Bazel 从jax
存储库构建。来自 XLA 存储库的jaxlib
部分被合并到构建中 作为 Bazel 子模块。要在构建过程中更新使用的 XLA 版本,必须在 Bazel 的WORKSPACE
中更新固定的版本。这是根据需要手动完成的,但可以根据构建的需求进行覆盖。
在jax
和jaxlib
发布之间如何跨界修改?
jaxlib
版本是一个粗糙的工具:它只能让我们推断发布版本。
然而,由于jax
和jaxlib
代码分布在无法在单个更改中原子更新的存储库中,我们需要在比我们的发布周期更精细的粒度上管理兼容性。为了管理细粒度兼容性,我们有额外的版本控制,这与jaxlib
发布版本号独立。
我们在XLA 存储库中的xla_client.py
中维护了一个额外的版本号(_version
)。其理念是,这个版本号在xla/python
中与 JAX 的 C++部分一起定义,也可以作为jax._src.lib.xla_extension_version
被 JAX Python 访问,并且在每次对 XLA/Python 代码进行更改且这些更改对jax
的向后兼容性有影响时,都必须递增。JAX Python 代码可以利用这个版本号来维护向后兼容性,例如:
from jax._src.lib import xla_extension_version # 123 is the new version number for _version in xla_client.py if xla_extension_version >= 123: # Use new code path ... else: # Use old code path.
请注意,这个版本号是为了帮助管理开发中未发布代码的兼容性而存在的,与已发布版本号的约束额外。发布版本也必须遵循上述兼容性规则。
在 JAX 中序列化副作用
原文:
jax.readthedocs.io/en/latest/jep/10657-sequencing-effects.html
sharadmv@ May 9 2022
概述
当我们编写 JAX 代码时,通常可以假装我们在编写单线程、即时执行的 Python 代码,尽管在底层,JAX 及其运行时可能在后台异步执行。只要我们编写纯净(无副作用)的代码,这些性能优化通常对我们是不可见的,不会干扰我们的单线程心理模型。异步执行非常棒 — 我们可以获得高效、并行的代码,而无需考虑任何问题!
然而,在存在副作用的情况下,这种幻象开始破裂,我们心理模型的裂缝开始显现。具体来说,当我们考虑副作用发生的顺序时,这些差异就会显现出来。
在这篇设计说明中,我们探讨了 JAX 执行模型与副作用顺序之间的交互。我们还提供了一种强制执行“单线程”副作用顺序的方法。
背景
当我们编写以下 Python 代码时
def f(): print("hello") return 2 def g(): print("world") return 3 f() g()
我们期望 "hello"
在 "world"
之前被打印出来。这似乎是显而易见的,但考虑以下 JAX 代码:
@partial(jax.jit, device=<device 0>) def f(): return 2 @partial(jax.jit, device=<device 1>) def g(): return 3 f() g()
在许多情况下,JAX 将并行执行 f
和 g
,将计算分发到不同的线程 —— g
可能会在 f
之前执行。并行执行是一种很好的性能优化,特别是在设备间的复制成本昂贵时(详见异步调度说明了解更多详情)。然而,在实践中,我们通常不需要考虑异步调度,因为我们编写的是纯函数,只关心函数的输入和输出 —— 我们自然会在未来的值上阻塞。
但是,现在想象一下,我们有一个 jax.print
函数,可以在 JIT 编译的 JAX 函数内部工作(例如 host_callback.id_print
就是一个例子)。让我们回到之前的例子,但现在加入了打印输出。
@partial(jax.jit, device=<device 0>) def f(): jax.print("hello") return 2 @partial(jax.jit, device=<device 1>) def g(): jax.print("world") return 3 f() g()
由于异步调度的存在,我们实际上可以看到 "world"
在 "hello"
之前被打印出来。打印输出副作用的重新排序破坏了单线程执行模型的幻象。
另一个副作用可以“揭示”无序执行的示例是当我们编译 JAX 程序时。考虑以下 JAX 代码:
@jax.jit def f(x): jax.print("hello") jax.print("world") return x
尽管在 Python 中,我们先写了 "hello"
的打印,然后是 "world"
的打印,但是像 XLA 这样的编译器可以自由地重新排序它们,因为这两个打印之间没有显式的数据依赖关系。
动机
我们希望支持“有序”效果。所谓有序,意味着效果发生的顺序与我们在执行单线程 Python 程序时的顺序相同。这是我们的主要愿望。在存在显式并行性(如pmap
或用户线程)的情况下,我们不需要保持这种行为,但至少如果用户没有显式请求并行性,我们希望保持单线程顺序。
在深入讨论之前,让我们先退后一步,问问自己,如果我们为了性能而重新排序效果,这样做是否可以接受?反之,我们是否需要完全强制效果的顺序?在某些情况下,我们不需要排序。也许某些副作用不应该影响 JAX 程序的性能。然而,对于其他副作用,我们可能希望强制单线程程序顺序,以防止用户得到反直觉的行为。考虑一个日志效果。
@jax.jit def f(x, y): log_value(x) log_value(y) f(1, 2)
如果log
正在改变全局列表,我们可能期望在添加y
之前添加x
。为了更严格的效果,我们可能希望能够对效果进行排序。
强制有序效果
我们用来强制计算顺序的主要工具是数据依赖性。简单来说,如果函数g
的输入是函数f
的输出,那么必须先执行f
,再执行g
。
然而,我们可能会有像打印这样的副作用,这些副作用根本没有任何输入,因此我们无法简单地对它们进行排序。因此,我们使用令牌作为向计算中注入人为数据依赖性的手段。
什么是令牌?令牌只是可以在计算中穿插的虚拟值。通过在多个计算中穿插相同的令牌,我们强制它们按照特定顺序进行。让我们看看前面的打印示例,加入令牌后会是什么样子:
@jax.jit def f(token, x): token = jax.print(token, "hello") token = jax.print(token, "world") return token, x
如果我们重写jax.print
以接受并返回一个令牌,我们现在已经按顺序序列化了两个打印,因为第二个打印的输入依赖于第一个打印的输出。实际上,token
的实际值可以是任何东西,但我们会看到,这些令牌对用户来说是不可见的。
运行时令牌与编译器令牌
现在我们将开始讨论实现细节。实际上,我们需要两种不同类型的令牌来序列化效果:一种用于上述重新排序的每种源,我们需要运行时令牌来序列化异步调度的有副作用的计算,我们还需要编译器令牌来序列化计算内部的效果。
实际上,我们的计算将重写为以下形式:
@jax.jit def f(runtime_token, x): compiler_token = new_compiler_token() compiler_token = jax.print(compiler_token, "hello") compiler_token = jax.print(compiler_token, "world") return runtime_token, x
注意运行时令牌仅在 JIT 边界使用,而编译器令牌仅在编译后的代码中使用。编译器令牌是在“降级”过程中创建的(我们将 Python 代码转换为类似 HLO 或 StableHLO 的低级表示),但运行时令牌需要在 Python 中进行管理,因为它们在 JIT 化的函数中穿插输入和输出。
此外,请注意运行时令牌与编译器令牌之间是“断开”的,这意味着它们之间没有数据依赖关系。这可能是危险的,因为我们会失去两个调度函数调用体之间的数据依赖性。然而,如果我们假设“严格执行”——即一个调度函数只有在其所有输入准备就绪且所有输出同时准备就绪时才会开始执行——我们可以安全地创建一个新的编译器令牌,并返回一个不依赖于输出的运行时令牌。
管理运行时令牌
为了代表用户管理运行时令牌,我们需要插入到 JAX 的调度机制中。每当我们调用 JIT 编译的函数时,我们最终会得到一个看起来像这样的函数:
def _execute(compiled_computation, *args): outputs = compiled_computation.execute(*args) return outputs
此时我们需要"注入"运行时令牌到计算中,并从计算的输出中"提取"它们:
def _execute(compiled_computation, *args): runtime_token = get_runtime_token() # Grab global token runtime_token, *outputs = compiled_computation.execute(runtime_token, *args) update_runtime_token(runtime_token) # Update global token return outputs
runtime_token
究竟是什么?嗯,我们需要能够将其传递给compiled_computation
,这意味着它需要是某种数组(目前来说,由于在编译的 JAX 代码内外没有共享的令牌表示,我们可以使用一个(0,)
形状的数组来最小化开销)。
我们还需要考虑多设备使用情况,例如第一个示例中,我们首先在设备 0 上调用 JIT 编译的函数,然后在设备 1 上调用另一个函数。在这种情况下,我们还需要将第一个计算返回的运行时令牌(位于设备 0 上)复制到设备 1,以便将其传递给第二个计算。如果两个后续计算共享相同的设备,则此复制是不必要的。
添加编译器令牌
当我们将 Python 代码降级为 HLO 或 StableHLO 时,我们需要在计算开始时创建一个令牌,并确保在需要对顺序进行排序的副作用计算时可用。副作用计算将该令牌作为输入,并将其作为输出返回。
实现此令牌线程涉及升级 JAX 降级机制以自动进行此类记账。主要挑战涉及处理像调用原语和控制流原语这样的高阶原语。在本设计说明中,我们不会详细讨论如何处理这些挑战。
阻塞输出令牌
为运行时和编译器令牌增加支持以进行副作用计算序列化是很重要的,但令牌还有另一个微妙的用例,即在副作用计算上阻塞。即使我们不希望副作用计算是有序的,我们可能仍然希望等待其完成。目前我们有jax.block_until_ready
,它会等待直到未来的值准备就绪。然而,对于副作用计算,我们可能有一些没有返回值但仍在执行副作用的函数。以这里的简单示例为例:
@jax.jit def f(): jax.print("hello world") return f() # Executed asynchronously
这个编译后的计算不接受任何显式输入,也没有显式输出。如果它是一个有序的打印效果,我们可以阻塞返回的运行时令牌,但是当这是一个无序计算时,我们不执行任何令牌线程。当我们没有输出值来调用block_until_ready
时,我们如何等待f()
执行结束呢?嗯,我们可以应用相同的令牌策略,除了我们只返回运行时令牌而不将它们作为输入。这将给我们一个可以阻塞的值,该值仅在f()
执行完成后才会准备好。我们将这些令牌称为输出令牌。我们最终得到了如下所示的函数:
@jax.jit def f(): jax.print("hello world") return new_runtime_token() f() # Executed asynchronously
在幕后,我们将以与管理运行时令牌相同的方式来管理输出令牌,但提供一种方法让用户在当前一组输出令牌上阻塞。与运行时令牌不同,输出令牌需要是特定于设备的。考虑单设备使用情况:
@jax.jit def f(): jax.print("hello") @jax.jit def g(): jax.print("world") f() g()
由于f()
和g()
在同一设备上执行,阻塞g()
的输出令牌有效地阻塞了f()
,因为(目前为止!),JAX 运行时不会交错执行在同一设备上执行的计算。当然,如果情况改变,我们将不得不重新审视整个设计。
然而,考虑两个设备使用情况:
@partial(jax.jit, device=<device 0>) def f(): jax.print("hello") @partial(jax.jit, device=<device 1>) def g(): jax.print("world") f() g()
这里我们不想显式地序列f()
和g()
,但是希望等待它们都完成。我们需要一个f()
的输出令牌和一个g()
的输出令牌,并且我们将阻塞在这两个令牌上:
@partial(jax.jit, device=<device 0>) def f(): jax.print("hello") return new_runtime_token() @partial(jax.jit, device=<device 1>) def g(): jax.print("world") return new_runtime_token() t0 = f() t1 = g() block_until_ready((t0, t1))
因此,我们需要每个设备的输出令牌,这样我们就可以避免在不同设备上对计算进行排序,同时可以阻塞具有副作用的计算。我们最终得到了以下(大致)对 JAX 调度机制的更改:
def _execute(compiled_computation, *args): output_token, *outputs = compiled_computation.execute(runtime_token, *args) update_output_token(output_token, compiled_computation.device) return outputs
我们还需要暴露一个函数来阻塞输出令牌:
def effects_barrier(): output_token.block_until_ready()
注意,阻塞输出令牌可能不太常见,因为大多数 JAX 计算将返回一个值来阻塞。然而,输出令牌对于测试和分析非常有用,并且支持它们是很好的,这样我们就有了一个一致且有条理的效果系统。
更多细节
- 所有上述的令牌管理基础设施将是线程本地的。这意味着每个用户线程将有自己独立的运行时令牌流。排序仅在用户线程级别上承诺。
- 在实践中,我们每个效果有一个运行时令牌。不同实例的该效果将被排序。这是为了避免对彼此可能没有任何关系的具有影响力的计算进行排序。从技术上讲,这与我们最初的目标相矛盾,即强制执行单线程 Python 程序的顺序,但这是一个可以通过同时具有“效果”特定令牌和“全局”令牌来调节的折衷方案。
JAX 中文文档(十二)(2)https://developer.aliyun.com/article/1559714