JAX 中文文档(十五)(2)https://developer.aliyun.com/article/1559768
jax.experimental.host_callback 模块
原文:
jax.readthedocs.io/en/latest/jax.experimental.host_callback.html
在 JAX 加速器代码中调用 Python 函数的原语。
警告
自 2024 年 3 月 20 日起,host_callback API 已弃用。功能已被 新的 JAX 外部回调 所取代。请参阅 google/jax#20385。
此模块介绍了主机回调函数 call()
、id_tap()
和 id_print()
,它们将其参数从设备发送到主机,并在主机上调用用户定义的 Python 函数,可选地将结果返回到设备计算中。
我们展示了下面如何使用这些函数。我们从 call()
开始,并讨论从 JAX 调用 CPU 上任意 Python 函数的示例,例如使用 NumPy CPU 自定义核函数。然后我们展示了使用 id_tap()
和 id_print()
,它们的限制是不能将主机返回值传回设备。这些原语通常更快,因为它们与设备代码异步执行。特别是它们可用于连接到和调试 JAX 代码。
使用 call()
调用主机函数并将结果返回给设备
使用 call()
调用主机上的计算并将 NumPy 数组返回给设备上的计算。主机计算在以下情况下非常有用,例如当设备计算需要一些需要在主机上进行 I/O 的数据,或者它需要一个在主机上可用但不希望在 JAX 中编码的库时。例如,在 JAX 中一般矩阵的特征值分解在 TPU 上不起作用。我们可以从任何 JAX 加速计算中调用 Numpy 实现,使用主机计算:
# This function runs on the host def host_eig(m: np.ndarray) -> np.ndarray: return np.linalg.eigvals(m) # This function is used in JAX def device_fun(m): # We send "m" to the host, asking it to call "host_eig" and return the result. # We have to specify the result shape and dtype, either in the form of an # example return value or any object that has `shape` and `dtype` attributes, # e.g., a NumPy array or a `jax.ShapeDtypeStruct`. return hcb.call(host_eig, m, # Given an input of shape (..., d, d), eig output has shape (..., d) result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype))
call()
函数和 Python 主机函数都接受一个参数并返回一个结果,但这些可以是 pytrees。注意,我们必须告诉 call()
从主机调用中期望的形状和 dtype,使用 result_shape
关键字参数。这很重要,因为设备代码是按照这个期望进行编译的。如果实际调用产生不同的结果形状,运行时会引发错误。通常,这样的错误以及主机计算引发的异常可能很难调试。请参见下面的调试部分。这对 call()
是一个问题,但对于 id_tap()
不是,因为对于后者,设备代码不期望返回值。
call()
API 可以在 jit 或 pmap 计算内部使用,或在 cond/scan/while 控制流内部使用。当在 jax.pmap()
内部使用时,将从每个参与设备中分别调用主机:
def host_sin(x, *, device): # The ``device`` argument is passed due to ``call_with_device=True`` below. print(f"Invoking host_sin with {x.shape} on {device}") return np.sin(x) # Use pmap to run the computation on two devices jax.pmap(lambda x: hcb.call(host_sin, x, result_shape=x, # Ask that the `host_sin` function be passed `device=dev` call_with_device=True))( np.ones((2, 4), dtype=np.float32)) # prints (in arbitrary order) # Invoking host_sin with (4,) on cpu:0 # Invoking host_sin with (4,) on cpu:1
请注意,call()
不支持任何 JAX 转换,但如下所示,可以利用现有的支持来自定义 JAX 中的导数规则。
使用id_tap()
在主机上调用 Python 函数,不返回任何值。
id_tap()
和id_print()
是call()
的特殊情况,当您只希望 Python 回调的副作用时。这些函数的优点是一旦参数已发送到主机,设备计算可以继续进行,而无需等待 Python 回调返回。对于id_tap()
,您可以指定要调用的 Python 回调函数,而id_print()
则使用一个内置回调,在主机的标准输出中打印参数。传递给id_tap()
的 Python 函数接受两个位置参数(从设备计算中获取的值以及一个transforms
元组,如下所述)。可选地,该函数可以通过关键字参数device
传递设备从中获取的设备。
几个示例:
def host_func(arg, transforms): ...do something with arg... # calls host_func(2x, []) on host id_tap(host_func, 2 * x) # calls host_func((2x, 3x), []) id_tap(host_func, (2 * x, 3 * x)) # The argument can be a pytree # calls host_func(2x, [], device=jax.devices()[0]) id_tap(host_func, 2 * x, tap_with_device=True) # Pass the device to the tap # calls host_func(2x, [], what='activation') id_tap(functools.partial(host_func, what='activation'), 2 * x) # calls host_func(dict(x=x, y=y), what='data') id_tap(lambda tap, transforms: host_func(tap, what='data'), dict(x=x, y=y))
所有上述示例都可以改用id_print()
,只是id_print()
会在主机上打印位置参数,以及任何额外的关键字参数和自动关键字参数transforms
之间的区别。
使用barrier_wait()
等待所有回调函数执行结束。
如果你的 Python 回调函数有副作用,可能需要等到计算完成,以确保副作用已被观察到。你可以使用barrier_wait()
函数来实现这一目的:
accumulator = [] def host_log(arg, transforms): # We just record the arguments in a list accumulator.append(arg) def device_fun(x): id_tap(host_log, x) id_tap(host_log, 2. * x) jax.jit(device_fun)(1.) jax.jit(device_fun)(1.) # At this point, we have started two computations, each with two # taps, but they may not have yet executed. barrier_wait() # Now we know that all the computations started before `barrier_wait` # on all devices, have finished, and all the callbacks have finished # executing.
请注意,barrier_wait()
将在jax.local_devices()
的每个设备上启动一个微小的计算,并等待所有这些计算的结果被接收。
一个替代方案是使用barrier_wait()
仅等待计算结束,如果所有回调都是call()
的话:
accumulator = p[] def host_log(arg): # We just record the arguments in a list accumulator.append(arg) return 0. # return something def device_fun(c): y = call(host_log, x, result_shape=jax.ShapeDtypeStruct((), np.float32)) z = call(host_log, 2. * x, result_shape=jax.ShapeDtypeStruct((), np.float32)) return y + z # return something that uses both results res1 = jax.jit(device_fun)(1.) res2 = jax.jit(device_fun)(1.) res1.block_until_ready() res2.block_until_ready()
并行化转换下的行为
在存在jax.pmap()
的情况下,代码将在多个设备上运行,并且每个设备将独立地执行其值。建议为id_print()
或id_tap()
使用tap_with_device
选项可能会有所帮助,以便查看哪个设备发送了哪些数据:
jax.pmap(power3, devices=jax.local_devices()[:2])(np.array([3., 4.]) # device=cpu:0 what=x,x²: (3., 9.) # from the first device # device=cpu:1 what=x,x²: (4., 16.) # from the second device
使用jax.pmap()
和多个主机上的多个设备时,每个主机将从其所有本地设备接收回调,带有与每个设备切片对应的操作数。对于call()
,回调必须仅向每个设备返回与相应设备相关的结果切片。
当使用实验性的pjit.pjit()
时,代码将在多个设备上运行,并在输入的不同分片上。当前主机回调的实现将确保单个设备将收集并输出整个操作数,在单个回调中。回调函数应返回整个数组,然后将其发送到发出输出的同一设备的单个进料中。然后,此设备负责将所需的分片发送到其他设备:
with jax.sharding.Mesh(jax.local_devices()[:2], ["d"]): pjit.pjit(power3, in_shardings=(P("d"),), out_shardings=(P("d"),))(np.array([3., 4.])) # device=TPU:0 what=x,x²: ( [3., 4.], # [9., 16.] )
请注意,在一个设备上收集操作数可能会导致内存不足,如果操作数分布在多个设备上则情况类似。
当在多个设备上的多个主机上使用 pjit.pjit()
时,仅设备 0(相对于网格)上的主机将接收回调,其操作数来自所有参与设备上的所有主机。对于 call()
,回调必须返回所有设备上所有主机的整个数组。
在 JAX 自动微分转换下的行为
在 JAX 自动微分转换下使用时,主机回调函数仅处理原始值。考虑以下示例:
def power3(x): y = x * x # Print both 'x' and 'x²'. Must pack as a tuple. hcb.id_print((x, y), what="x,x²") return y * x power3(3.) # what: x,x² : (3., 9.)
(您可以在 host_callback_test.HostCallbackTapTest.test_tap_transforms
中查看这些示例的测试。)
当在 jax.jvp()
下使用时,仅会有一个回调处理原始值:
jax.jvp(power3, (3.,), (0.1,)) # what: x,x² : (3., 9.)
类似地,对于 jax.grad()
,我们仅从前向计算中得到一个回调:
jax.grad(power3)(3.) # what: x,x² : (3., 9.)
如果您想在 jax.jvp()
中对切线进行回调处理,可以使用 custom_jvp
。例如,您可以定义一个除了其 custom_jvp
会打印切线之外无趣的函数:
@jax.custom_jvp def print_tangents(arg): return None @print_tangents.defjvp def print_tangents_jvp(primals, tangents): arg_dot, = tangents hcb.id_print(arg_dot, what="tangents") return primals, tangents
然后,您可以在想要触发切线的位置使用此函数:
def power3_with_tangents(x): y = x * x # Print both 'x' and 'x²'. Must pack as a tuple. hcb.id_print((x, y), what="x,x²") print_tangents((x, y)) return y * x jax.jvp(power3_with_tangents, (3.,), (0.1,)) # what: x,x² : (3., 9.) # what: tangents : (0.1, 0.6)
您可以在 jax.grad()
中做类似的事情来处理余切。这时,您必须小心使用在其余计算中需要的余切值。因此,我们使 print_cotangents
返回其参数:
@jax.custom_vjp def print_cotangents(arg): # Must return the argument for which we want the cotangent. return arg # f_fwd: a -> (b, residual) def print_cotangents_fwd(arg): return print_cotangents(arg), None # f_bwd: (residual, CT b) -> [CT a] def print_cotangents_bwd(residual, ct_b): hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream) return ct_b, print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd) def power3_with_cotangents(x): y = x * x # Print both 'x' and 'x²'. Must pack as a tuple. hcb.id_print((x, y), what="x,x²", output_stream=testing_stream) (x1, y1) = print_cotangents((x, y)) # Must use the output of print_cotangents return y1 * x1 jax.grad(power3_with_cotangents)(3.) # what: x,x² : (3., 9.) # what: cotangents : (9., 3.)
如果您使用 ad_checkpoint.checkpoint()
来重新生成反向传播的残差,则原始计算中的回调将被调用两次:
jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.) # what: x,x² : (3., 9.) # what: x,x² : (27., 729.) # what: x,x² : (3., 9.)
这些回调依次是:内部 power3
的原始计算,外部 power3
的原始计算,以及内部 power3
的残差重新生成。
在 jax.vmap
下的行为
主机回调函数 id_print()
和 id_tap()
支持矢量化转换 jax.vmap()
。
对于 jax.vmap()
,回调的参数是批量处理的,并且回调函数会传递一个特殊的 transforms
,其中包含转换描述符列表,格式为 ("batch", {"batch_dims": ...})
,其中 ...
表示被触发值的批处理维度(每个参数一个条目,None
表示广播的参数)。
jax.vmap(power3)(np.array([2., 3.]))
# transforms: [(‘batch’, {‘batch_dims’: (0, 0)})] what: x,x² : ([2., 3.], [4., 9.])
请参阅 id_tap()
、id_print()
和 call()
的文档。
更多用法示例,请参阅 tests/host_callback_test.py
。
使用 call()
调用 TensorFlow 函数,支持反向模式自动微分
主机计算的另一个可能用途是调用为另一个框架编写的库,如 TensorFlow。在这种情况下,通过使用 jax.custom_vjp()
机制来支持主机回调的 JAX 自动微分变得有趣。
一旦理解了 JAX 自定义 VJP 和 TensorFlow autodiff 机制,这就相对容易做到。可以在 host_callback_to_tf_test.py 中的 call_tf_full_ad
函数中看到如何实现这一点。该示例还支持任意高阶微分。
请注意,如果只想从 JAX 调用 TensorFlow 函数,也可以使用 jax2tf.call_tf function。
使用 call()
在另一个设备上调用 JAX 函数,支持反向模式自动微分
我们可以使用主机计算来调用另一个设备上的 JAX 计算,并不奇怪。参数从加速器发送到主机,然后发送到将运行 JAX 主机计算的外部设备,然后将结果发送回原始加速器。
可以在 host_callback_test.py 中的 call_jax_other_device function
中看到如何实现这一点。
低级细节和调试
主机回调函数将按照在设备上执行发送操作的顺序执行。
多个设备的主机回调函数可能会交错执行。设备数据由 JAX 运行时管理的单独线程接收(每个设备一个线程)。运行时维护一个可配置大小的缓冲区(参见标志 --jax_host_callback_max_queue_byte_size
)。当缓冲区满时,所有接收线程将被暂停,最终暂停设备上的计算。对于更多关于 outfeed 接收器运行时机制的细节,请参阅 runtime code。
要等待已经启动在设备上的计算的所有数据到达并被处理,可以使用 barrier_wait()
。
用户定义的回调函数抛出的异常以及它们的堆栈跟踪都会被记录,但接收线程不会停止。相反,最后一个异常被记录,并且随后的 barrier_wait()
将在任何一个 tap 函数中发生异常时引发 CallbackException
。此异常将包含最后异常的文本和堆栈跟踪。
对于必须将结果返回给调用原点设备的回调函数(如call()
),存在进一步的复杂性。这在 CPU/GPU 设备与 TPU 设备上处理方式不同。
在 CPU/GPU 设备上,为了避免设备计算因等待永远不会到达的结果而陷入困境,在处理回调过程中出现任何错误(无论是由用户代码自身引发还是由于返回值与期望返回形状不匹配而引发),我们会向设备发送一个形状为 int8[12345]
的“虚假”结果。这将导致设备计算中止,因为接收到的数据与其预期的数据不同。在 CPU 上,运行时将崩溃并显示特定的错误消息:
` 检查失败:buffer->length() == buffer_length (12345 vs. ...) `
在 GPU 上,这种失败会更加用户友好,并将其作为以下形式暴露给 Python 程序:
` RET_CHECK 失败 ... 输入源缓冲区形状为 s8[12345] 不匹配 ... `
要调试这些消息的根本原因,请参阅调试部分。
在 TPU 设备上,目前没有对输入源进行形状检查,因此我们采取更安全的方式,在出现错误时不发送此虚假结果。这意味着计算将会挂起,且不会引发异常(但回调函数中的任何异常仍将出现在日志中)。
当前实现使用 XLA 提供的出料机制。该机制本身在某种程度上相当原始,因为接收器必须准确知道每个传入数据包的形状和预期的数据包数量。这使得它在同一计算中难以用于多种数据类型,并且在非常量迭代次数的条件或循环中几乎不可能使用。此外,直接使用出料机制的代码无法由 JAX 进行转换。所有这些限制都通过主机回调函数得到解决。此处引入的 tapping API 可以轻松地用于多种目的共享出料机制,同时支持所有转换。
注意,在使用主机回调函数后,您不能直接使用 lax.outfeed。如果以后需要使用 lax.outfeed,则可能需要 stop_outfeed_receiver()
。
由于实际调用您的回调函数是从 C++ 接收器进行的,因此调试这些调用可能会很困难。特别是,堆栈跟踪不会包含调用代码。您可以使用标志 jax_host_callback_inline
(或环境变量 JAX_HOST_CALLBACK_INLINE
)确保回调函数的调用是内联的。这仅在调用位于非常量迭代次数的阶段上下文之外时有效(例如 jit()
或控制流原语)。
C++ 接收器 会在首次调用 id_tap()
时自动启动。为了正确停止它,在启动时注册了一个 atexit
处理程序,以带有日志名称“at_exit”调用 barrier_wait()
。
有几个环境变量可用于启用 C++ outfeed 接收器后端的日志记录(接收器后端)。
TF_CPP_MIN_LOG_LEVEL=0
:将 INFO 日志打开,适用于以下所有内容。TF_CPP_MIN_VLOG_LEVEL=3
:将所有 VLOG 日志级别为 3 的行为设为 INFO 日志。这可能有些过多,但你将看到哪些模块记录了相关信息,然后你可以选择从哪些模块记录日志。TF_CPP_VMODULE==3
(模块名可以是 C++ 或 Python,不带扩展名)。
你还应该使用 --verbosity=2
标志,这样你就可以看到 Python 的日志。
例如,你可以尝试在 host_callback
模块中启用日志记录:TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple
如果你想在更低级别的实现模块中启用日志记录,请尝试:TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple
(对于 bazel 测试,请使用 –test_arg=–vmodule=…
仍需完成:
- 更多性能测试。
- 探索在 TPU 上进行外部编译实现。
- 探索在 CPU 和 GPU 上使用 XLA CustomCall 进行实现。
API
id_tap (tap_func, arg, *[, result, …]) |
主机回调 tap 原语,类似于带有 tap_func 调用的恒等函数。 |
id_print (arg, *[, result, tap_with_device, …]) |
类似于 id_tap() ,带有打印 tap 函数。 |
call (callback_func, arg, *[, result_shape, …]) |
调用主机,并期望得到结果。 |
barrier_wait ([logging_name]) |
阻塞调用线程,直到所有当前 outfeed 处理完毕。 |
CallbackException |
表示某些回调函数发生异常。 |
JAX 中文文档(十五)(4)https://developer.aliyun.com/article/1559771