JAX 中文文档(四)(3)https://developer.aliyun.com/article/1559795
提前降低和编译
JAX 提供了几种转换,如jax.jit
和jax.pmap
,返回一个编译并在加速器或 CPU 上运行的函数。正如 JIT 缩写所示,所有编译都是即时执行的。
有些情况需要进行提前(AOT)编译。当你希望在执行之前完全编译,或者希望控制编译过程的不同部分何时发生时,JAX 为您提供了一些选项。
首先,让我们回顾一下编译的阶段。假设f
是由jax.jit()
输出的函数/可调用对象,例如对于某个输入可调用对象F
,f = jax.jit(F)
。当它用参数调用时,例如f(x, y)
,其中x
和y
是数组,JAX 按顺序执行以下操作:
- Stage out原始 Python 可调用
F
的特殊版本到内部表示。专门化反映了F
对从参数x
和y
的属性推断出的输入类型的限制(通常是它们的形状和元素类型)。 - Lower这种特殊的阶段计算到 XLA 编译器的输入语言 StableHLO。
- Compile降低的 HLO 程序以生成针对目标设备(CPU、GPU 或 TPU)的优化可执行文件。
- Execute使用数组
x
和y
作为参数执行编译后的可执行文件。
JAX 的 AOT API 允许您直接控制步骤#2、#3 和#4(但不包括#1),以及沿途的一些其他功能。例如:
>>> import jax >>> def f(x, y): return 2 * x + y >>> x, y = 3, 4 >>> lowered = jax.jit(f).lower(x, y) >>> # Print lowered HLO >>> print(lowered.as_text()) module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}, %arg1: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { %c = stablehlo.constant dense<2> : tensor<i32> %0 = stablehlo.multiply %c, %arg0 : tensor<i32> %1 = stablehlo.add %0, %arg1 : tensor<i32> return %1 : tensor<i32> } } >>> compiled = lowered.compile() >>> # Query for cost analysis, print FLOP estimate >>> compiled.cost_analysis()[0]['flops'] 2.0 >>> # Execute the compiled function! >>> compiled(x, y) Array(10, dtype=int32, weak_type=True)
请注意,降低的对象只能在它们被降低的同一进程中使用。有关导出用例,请参阅导出和序列化 API。
有关降低和编译函数提供的功能的更多详细信息,请参见jax.stages
文档。
在上面的jax.jit
的位置,您还可以lower(...)``jax.pmap()
的结果,以及pjit
和xmap
(分别来自jax.experimental.pjit
和jax.experimental.maps
)。在每种情况下,您也可以类似地compile()
结果。
所有jit
的可选参数——如static_argnums
——在相应的降低、编译和执行中都得到尊重。同样适用于pmap
、pjit
和xmap
。
在上述示例中,我们可以将lower
的参数替换为具有shape
和dtype
属性的任何对象:
>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32')) >>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y) Array(10, dtype=int32)
更一般地说,lower
只需其参数结构上提供 JAX 必须了解的内容进行专门化和降低。对于像上面的典型数组参数,这意味着shape
和dtype
字段。相比之下,对于静态参数,JAX 需要实际的数组值(下面会详细说明)。
使用与其降低不兼容的参数调用 AOT 编译函数会引发错误:
>>> x_1d = y_1d = jnp.arange(3) >>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d) ... Traceback (most recent call last): TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are: Argument 'x' compiled with int32[] and called with int32[3] Argument 'y' compiled with int32[] and called with int32[3] >>> x_f = y_f = jnp.float32(72.) >>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f) ... Traceback (most recent call last): TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are: Argument 'x' compiled with int32[] and called with float32[] Argument 'y' compiled with int32[] and called with float32[]
与此相关的是,AOT 编译函数不能通过 JAX 的即时转换(如jax.jit
、jax.grad()
和jax.vmap()
)进行转换。
使用静态参数进行降低
使用静态参数进行降级强调了传递给jax.jit
的选项、传递给lower
的参数以及调用生成的编译函数所需的参数之间的交互。继续我们上面的示例:
>>> lowered_with_x = jax.jit(f, static_argnums=0).lower(7, 8) >>> # Lowered HLO, specialized to the *value* of the first argument (7) >>> print(lowered_with_x.as_text()) module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) { %c = stablehlo.constant dense<14> : tensor<i32> %0 = stablehlo.add %c, %arg0 : tensor<i32> return %0 : tensor<i32> } } >>> lowered_with_x.compile()(5) Array(19, dtype=int32, weak_type=True)
lower
的结果不能直接序列化以供在不同进程中使用。有关此目的的额外 API,请参见导出和序列化。
注意,这里的lower
像往常一样接受两个参数,但随后生成的编译函数仅接受剩余的非静态第二个参数。静态的第一个参数(值为 7)在降级时被视为常量,并内置到降级计算中,其中可能会与其他常量一起折叠。在这种情况下,它的乘以 2 被简化为常量 14。
尽管上面lower
的第二个参数可以被一个空的形状/数据类型结构替换,但静态的第一个参数必须是一个具体的值。否则,降级将会出错:
>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar) Traceback (most recent call last): TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct' >>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5) Array(25, dtype=int32)
AOT 编译的函数不能被转换
编译函数专门针对一组特定的参数“类型”,例如我们正在运行的示例中具有特定形状和元素类型的数组。从 JAX 的内部角度来看,诸如jax.vmap()
之类的转换会以一种方式改变函数的类型签名,使得已编译的类型签名失效。作为一项政策,JAX 简单地禁止已编译的函数参与转换。示例:
>>> def g(x): ... assert x.shape == (3, 2) ... return x @ jnp.ones(2) >>> def make_z(*shape): ... return jnp.arange(np.prod(shape)).reshape(shape) >>> z, zs = make_z(3, 2), make_z(4, 3, 2) >>> g_jit = jax.jit(g) >>> g_aot = jax.jit(g).lower(z).compile() >>> jax.vmap(g_jit)(zs) Array([[ 1., 5., 9.], [13., 17., 21.], [25., 29., 33.], [37., 41., 45.]], dtype=float32) >>> jax.vmap(g_aot)(zs) Traceback (most recent call last): TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax._src.interpreters.batching.BatchTracer'>
当g_aot
参与自动微分(例如jax.grad()
)时也会引发类似的错误。为了一致性,jax.jit
的转换也被禁止,尽管jit
并没有实质性地修改其参数的类型签名。
调试信息和分析,在可用时
除了主要的 AOT 功能(分离和显式的降级、编译和执行),JAX 的各种 AOT 阶段还提供一些额外的功能,以帮助调试和收集编译器反馈。
例如,正如上面的初始示例所示,降级函数通常提供文本表示。编译函数也是如此,并且还提供来自编译器的成本和内存分析。所有这些都通过jax.stages.Lowered
和jax.stages.Compiled
对象上的方法提供(例如,上面的lowered.as_text()
和compiled.cost_analysis()
)。
这些方法旨在帮助手动检查和调试,而不是作为可靠的可编程 API。它们的可用性和输出因编译器、平台和运行时而异。这导致了两个重要的注意事项:
- 如果某些功能在 JAX 当前的后端上不可用,则其方法将返回某些微不足道的东西(类似于
False
)。例如,如果支持 JAX 的编译器不提供成本分析,则compiled.cost_analysis()
将为None
。 - 如果某些功能可用,则对应方法提供的内容仍然有非常有限的保证。返回值在 JAX 的配置、后端/平台、版本或甚至方法的调用之间,在类型、结构或值上不需要保持一致。JAX 无法保证
compiled.cost_analysis()
在一天的输出将会在随后的一天保持相同。
如果有疑问,请参阅 jax.stages
的包 API 文档。
检查暂停的计算
此笔记顶部列表中的第一个阶段提到专业化和分阶段,之后是降低。JAX 内部对其参数类型专门化的函数的概念,并非始终在内存中具体化为数据结构。要显式构建 JAX 在内部Jaxpr 中间语言中函数专门化的视图,请参见 jax.make_jaxpr()
。
导出和序列化
指南
- 导出和序列化分阶段计算
- 支持逆向模式自动微分(AD)
- 兼容性保证
- 跨平台和多平台导出
- 形状多态导出
- 设备多态导出
- 调用约定版本
- 从 jax.experimental.export 的迁移指南
- 形状多态性
- 形状多态性的正确性
- 使用维度变量进行计算
- 与 TensorFlow 的互操作性
导出和序列化分离计算
提前降级和编译的 API 生成的对象可用于调试或在同一进程中进行编译和执行。有时候,您希望将降级后的 JAX 函数序列化,以便在稍后的时间在单独的进程中进行编译和执行。这将允许您:
- 在另一个进程或机器上编译并执行该函数,而无需访问 JAX 程序,并且无需重复分离和降低级别,例如在推断系统中。
- 跟踪和降低一个在没有访问您希望稍后编译和执行该函数的加速器的机器上的函数。
- 存档 JAX 函数的快照,例如以便稍后能够重现您的结果。**注意:**请查看此用例的兼容性保证。
这里有一个例子:
>>> import re >>> import numpy as np >>> import jax >>> from jax import export >>> def f(x): return 2 * x * x >>> exported: export.Exported = export.export(jax.jit(f))( ... jax.ShapeDtypeStruct((), np.float32)) >>> # You can inspect the Exported object >>> exported.fun_name 'f' >>> exported.in_avals (ShapedArray(float32[]),) >>> print(re.search(r".*@main.*", exported.mlir_module()).group(0)) func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"} loc("x")) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) { >>> # And you can serialize the Exported to a bytearray. >>> serialized: bytearray = exported.serialize() >>> # The serialized function can later be rehydrated and called from >>> # another JAX computation, possibly in another process. >>> rehydrated_exp: export.Exported = export.deserialize(serialized) >>> rehydrated_exp.in_avals (ShapedArray(float32[]),) >>> def callee(y): ... return 3. * rehydrated_exp.call(y * 4.) >>> callee(1.) Array(96., dtype=float32)
序列化分为两个阶段:
- 导出以生成一个包含降级函数的 StableHLO 和调用它所需的元数据的
jax.export.Exported
对象。我们计划添加代码以从 TensorFlow 生成Exported
对象,并使用来自 TensorFlow 和 PyTorch 的Exported
对象。 - 使用 flatbuffers 格式的字节数组进行实际序列化。有关与 TensorFlow 的交互操作的替代序列化,请参阅与 TensorFlow 的互操作性。
支持反向模式 AD
序列化可以选择支持高阶反向模式 AD。这是通过将原始函数的 jax.vjp()
与原始函数一起序列化,直到用户指定的顺序(默认为 0,意味着重新水化的函数无法区分)完成的:
>>> import jax >>> from jax import export >>> from typing import Callable >>> def f(x): return 7 * x * x * x >>> # Serialize 3 levels of VJP along with the primal function >>> blob: bytearray = export.export(jax.jit(f))(1.).serialize(vjp_order=3) >>> rehydrated_f: Callable = export.deserialize(blob).call >>> rehydrated_f(0.1) # 7 * 0.1³ Array(0.007, dtype=float32) >>> jax.grad(rehydrated_f)(0.1) # 7*3 * 0.1² Array(0.21000001, dtype=float32) >>> jax.grad(jax.grad(rehydrated_f))(0.1) # 7*3*2 * 0.1 Array(4.2, dtype=float32) >>> jax.grad(jax.grad(jax.grad(rehydrated_f)))(0.1) # 7*3*2 Array(42., dtype=float32) >>> jax.grad(jax.grad(jax.grad(jax.grad(rehydrated_f))))(0.1) Traceback (most recent call last): ValueError: No VJP is available
请注意,在序列化时计算 VJP 函数是惰性的,当 JAX 程序仍然可用时。这意味着它遵守 JAX VJP 的所有特性,例如 jax.custom_vjp()
和 jax.remat()
。
请注意,重新水化的函数不支持任何其他转换,例如前向模式 AD(jvp)或 jax.vmap()
。
JAX 中文文档(四)(5)https://developer.aliyun.com/article/1559797