JAX 中文文档(四)(2)

简介: JAX 中文文档(四)

JAX 中文文档(四)(1)https://developer.aliyun.com/article/1559792


JAX 中的外部回调

原文:jax.readthedocs.io/en/latest/notebooks/external_callbacks.html

本指南概述了各种回调函数的用途,这些函数允许 JAX 运行时在主机上执行 Python 代码,即使在jitvmapgrad或其他转换的情况下也是如此。

为什么需要回调?

回调例程是在运行时执行主机端代码的一种方式。举个简单的例子,假设您想在计算过程中打印某个变量的。使用简单的 Python print 语句,如下所示:

import jax
@jax.jit
def f(x):
  y = x + 1
  print("intermediate value: {}".format(y))
  return y * 2
result = f(2) 
intermediate value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> 

打印的不是运行时值,而是跟踪时的抽象值(如果您对在 JAX 中的追踪不熟悉,可以在How To Think In JAX找到一个很好的入门教程)。

要在运行时打印值,我们需要一个回调,例如jax.debug.print

@jax.jit
def f(x):
  y = x + 1
  jax.debug.print("intermediate value: {}", y)
  return y * 2
result = f(2) 
intermediate value: 3 

通过将由y表示的运行时值传递回主机进程,主机可以打印值。

回调的种类

在早期版本的 JAX 中,只有一种类型的回调可用,即jax.experimental.host_callback中实现的。host_callback例程存在一些缺陷,现已弃用,而现在推荐使用为不同情况设计的几个回调:

  • jax.pure_callback(): 适用于纯函数,即没有副作用的函数。
  • jax.experimental.io_callback(): 适用于不纯的函数,例如读取或写入磁盘数据的函数。
  • jax.debug.callback(): 适用于应反映编译器执行行为的函数。

(我们上面使用的jax.debug.print()函数是jax.debug.callback()的一个包装器)。

从用户角度来看,这三种回调的区别主要在于它们允许什么样的转换和编译器优化。

回调函数 支持返回值 jit vmap grad scan/while_loop 保证执行
jax.pure_callback ❌¹
jax.experimental.io_callback ✅/❌² ✅³
jax.debug.callback

¹ jax.pure_callback可以与custom_jvp一起使用,使其与自动微分兼容。

² 当ordered=False时,jax.experimental.io_callbackvmap兼容。

³ 注意vmapscan/while_loopio_callback具有复杂的语义,并且其行为可能在未来的版本中更改。

探索jax.pure_callback

通常情况下,jax.pure_callback是您在想要执行纯函数的主机端时应使用的回调函数:即没有副作用的函数(如打印值、从磁盘读取数据、更新全局状态等)。

您传递给jax.pure_callback的函数实际上不需要是纯的,但它将被 JAX 的转换和高阶函数假定为纯的,这意味着它可能会被静默地省略或多次调用。

import jax
import jax.numpy as jnp
import numpy as np
def f_host(x):
  # call a numpy (not jax.numpy) operation:
  return np.sin(x).astype(x.dtype)
def f(x):
  result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.pure_callback(f_host, result_shape, x)
x = jnp.arange(5.0)
f(x)
Array([ 0\.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32) 

因为pure_callback可以省略或复制,它与jitvmap等转换以及像scanwhile_loop这样的高阶原语兼容性开箱即用:“”

jax.jit(f)(x) 
Array([ 0\.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32) 
jax.vmap(f)(x) 
Array([ 0\.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32) 
def body_fun(_, x):
  return _, f(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1] 
Array([ 0\.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32) 

然而,由于 JAX 无法审视回调的内容,因此pure_callback具有未定义的自动微分语义:

%xmode minimal 
Exception reporting mode: Minimal 
jax.grad(f)(x) 
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients. 

有关使用pure_callbackjax.custom_jvp的示例,请参见下文示例:pure_callbackcustom_jvp

通过设计传递给pure_callback的函数被视为没有副作用:这意味着如果函数的输出未被使用,编译器可能会完全消除回调:

def print_something():
  print('printing something')
  return np.int32(0)
@jax.jit
def f1():
  return jax.pure_callback(print_something, np.int32(0))
f1(); 
printing something 
@jax.jit
def f2():
  jax.pure_callback(print_something, np.int32(0))
  return 1.0
f2(); 

f1中,回调的输出在函数返回值中被使用,因此执行回调并且我们看到打印的输出。另一方面,在f2中,回调的输出未被使用,因此编译器注意到这一点并消除函数调用。这是对没有副作用的函数回调的正确语义。

探索jax.experimental.io_callback

jax.pure_callback()相比,jax.experimental.io_callback()明确用于与有副作用的函数一起使用,即具有副作用的函数。

例如,这是一个对全局主机端 numpy 随机生成器的回调。这是一个不纯的操作,因为在 numpy 中生成随机数的副作用是更新随机状态(请注意,这只是io_callback的玩具示例,并不一定是在 JAX 中生成随机数的推荐方式!)。

from jax.experimental import io_callback
from functools import partial
global_rng = np.random.default_rng(0)
def host_side_random_like(x):
  """Generate a random array like x using the global_rng state"""
  # We have two side-effects here:
  # - printing the shape and dtype
  # - calling global_rng, thus updating its state
  print(f'generating {x.dtype}{list(x.shape)}')
  return global_rng.uniform(size=x.shape).astype(x.dtype)
@jax.jit
def numpy_random_like(x):
  return io_callback(host_side_random_like, x, x)
x = jnp.zeros(5)
numpy_random_like(x) 
generating float32[5] 
Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ],      dtype=float32) 

io_callback默认与vmap兼容:

jax.vmap(numpy_random_like)(x) 
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[] 
Array([0.91275555, 0.60663575, 0.72949654, 0.543625  , 0.9350724 ],      dtype=float32) 

但请注意,这可能以任何顺序执行映射的回调。例如,如果在 GPU 上运行此代码,则映射输出的顺序可能会因每次运行而异。

如果保留回调的顺序很重要,可以设置ordered=True,在这种情况下,尝试vmap会引发错误:

@jax.jit
def numpy_random_like_ordered(x):
  return io_callback(host_side_random_like, x, x, ordered=True)
jax.vmap(numpy_random_like_ordered)(x) 
JaxStackTraceBeforeTransformation: ValueError: Cannot `vmap` ordered IO callback.
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
ValueError: Cannot `vmap` ordered IO callback. 

另一方面,scanwhile_loop无论是否强制顺序,都与io_callback兼容:

def body_fun(_, x):
  return _, numpy_random_like_ordered(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1] 
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[] 
Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544],      dtype=float32) 

pure_callback类似,如果向其传递不同的变量,io_callback在自动微分下会失败:

jax.grad(numpy_random_like)(x) 
JaxStackTraceBeforeTransformation: ValueError: IO callbacks do not support JVP.
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
ValueError: IO callbacks do not support JVP. 

然而,如果回调不依赖于不同的变量,它将执行:

@jax.jit
def f(x):
  io_callback(lambda: print('hello'), None)
  return x
jax.grad(f)(1.0); 
hello 

pure_callback不同,在此情况下编译器不会消除回调的执行,即使回调的输出在后续计算中未使用。

探索debug.callback

pure_callbackio_callback都对调用的函数的纯度做出了一些假设,并以各种方式限制了 JAX 的变换和编译机制的操作。而debug.callback基本上不对回调函数做出任何假设,因此在程序执行过程中完全反映了 JAX 的操作。此外,debug.callback不能向程序返回任何值。

from jax import debug
def log_value(x):
  # This could be an actual logging call; we'll use
  # print() for demonstration
  print("log:", x)
@jax.jit
def f(x):
  debug.callback(log_value, x)
  return x
f(1.0); 
log: 1.0 

调试回调兼容vmap

x = jnp.arange(5.0)
jax.vmap(f)(x); 
log: 0.0
log: 1.0
log: 2.0
log: 3.0
log: 4.0 

也兼容grad和其他自动微分转换。

jax.grad(f)(1.0); 
log: 1.0 

这可以使得debug.callbackpure_callbackio_callback更有用于通用调试。

示例:pure_callbackcustom_jvp

利用jax.pure_callback()的一个强大方式是将其与jax.custom_jvp结合使用(详见自定义导数规则了解更多关于custom_jvp的细节)。假设我们想要为尚未包含在jax.scipyjax.numpy包装器中的 scipy 或 numpy 函数创建一个 JAX 兼容的包装器。

在这里,我们考虑创建一个第一类贝塞尔函数的包装器,该函数实现在scipy.special.jv中。我们可以先定义一个简单的pure_callback

import jax
import jax.numpy as jnp
import scipy.special
def jv(v, z):
  v, z = jnp.asarray(v), jnp.asarray(z)
  # Require the order v to be integer type: this simplifies
  # the JVP rule below.
  assert jnp.issubdtype(v.dtype, jnp.integer)
  # Promote the input to inexact (float/complex).
  # Note that jnp.result_type() accounts for the enable_x64 flag.
  z = z.astype(jnp.result_type(float, z.dtype))
  # Wrap scipy function to return the expected dtype.
  _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)
  # Define the expected shape & dtype of output.
  result_shape_dtype = jax.ShapeDtypeStruct(
      shape=jnp.broadcast_shapes(v.shape, z.shape),
      dtype=z.dtype)
  # We use vectorize=True because scipy.special.jv handles broadcasted inputs.
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True) 

这使得我们可以从转换后的 JAX 代码中调用scipy.special.jv,包括使用jitvmap转换时:

from functools import partial
j1 = partial(jv, 1)
z = jnp.arange(5.0) 
print(j1(z)) 
[ 0\.          0.44005057  0.5767248   0.33905897 -0.06604332] 

这里是使用jit得到的相同结果:

print(jax.jit(j1)(z)) 
[ 0\.          0.44005057  0.5767248   0.33905897 -0.06604332] 

并且这里再次是使用vmap得到的相同结果:

print(jax.vmap(j1)(z)) 
[ 0\.          0.44005057  0.5767248   0.33905897 -0.06604332] 

然而,如果我们调用jax.grad,我们会看到一个错误,因为该函数没有定义自动微分规则:

jax.grad(j1)(z) 
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients. 

让我们为此定义一个自定义梯度规则。查看第一类贝塞尔函数的定义(Bessel Function of the First Kind),我们发现对于其关于参数z的导数有一个相对简单的递推关系:

[\begin{split} d J_\nu(z) = \left{ \begin{eqnarray} -J_1(z),\ &\nu=0\ [J_{\nu - 1}(z) - J_{\nu + 1}(z)]/2,\ &\nu\ne 0 \end{eqnarray}\right. \end{split}

\begin{split} d J_\nu(z) = \left{ \begin{eqnarray} -J_1(z),\ &\nu=0\ [J_{\nu - 1}(z) - J_{\nu + 1}(z)]/2,\ &\nu\ne 0 \end{eqnarray}\right. \end{split}\begin{split} d J_\nu(z) = \left{ \begin{eqnarray} -J_1(z),\ &\nu=0\ [J_{\nu - 1}(z) - J_{\nu + 1}(z)]/2,\ &\nu\ne 0 \end{eqnarray}\right. \end{split}

\begin{split} d J_\nu(z) = \left{ \begin{eqnarray} -J_1(z),\ &\nu=0\ [J_{\nu - 1}(z) - J_{\nu + 1}(z)]/2,\ &\nu\ne 0 \end{eqnarray}\right. \end{split}]

对于变量 (\nu) 的梯度更加复杂,但由于我们将v参数限制为整数类型,因此在这个例子中,我们不需要担心其梯度。

我们可以使用jax.custom_jvp来为我们的回调函数定义这个自动微分规则:

jv = jax.custom_jvp(jv)
@jv.defjvp
def _jv_jvp(primals, tangents):
  v, z = primals
  _, z_dot = tangents  # Note: v_dot is always 0 because v is integer.
  jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)
  djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))
  return jv(v, z), z_dot * djv_dz 

现在计算我们函数的梯度将会正确运行:

j1 = partial(jv, 1)
print(jax.grad(j1)(2.0)) 
-0.06447162 

此外,由于我们已经根据jv定义了我们的梯度,JAX 的架构意味着我们可以免费获得二阶及更高阶的导数:

jax.hessian(j1)(2.0) 
Array(-0.4003078, dtype=float32, weak_type=True) 

请记住,尽管这在 JAX 中完全正常运作,每次调用基于回调的jv函数都会导致将输入数据从设备传输到主机,并将scipy.special.jv的输出从主机传输回设备。当在 GPU 或 TPU 等加速器上运行时,这种数据传输和主机同步可能会导致每次调用jv时的显著开销。然而,如果您在单个 CPU 上运行 JAX(其中“主机”和“设备”位于同一硬件上),JAX 通常会以快速、零拷贝的方式执行此数据传输,使得这种模式相对直接地扩展了 JAX 的能力。


JAX 中文文档(四)(3)https://developer.aliyun.com/article/1559795

相关文章
|
9天前
|
并行计算 API C++
JAX 中文文档(九)(4)
JAX 中文文档(九)
12 1
|
9天前
|
存储 机器学习/深度学习 并行计算
JAX 中文文档(二)(5)
JAX 中文文档(二)
10 0
|
9天前
|
API 索引 Python
JAX 中文文档(三)(4)
JAX 中文文档(三)
6 0
|
9天前
|
机器学习/深度学习 API 索引
JAX 中文文档(二)(2)
JAX 中文文档(二)
11 0
|
9天前
|
测试技术 API Python
JAX 中文文档(八)(4)
JAX 中文文档(八)
12 0
|
9天前
JAX 中文文档(九)(3)
JAX 中文文档(九)
9 0
|
9天前
|
机器学习/深度学习 存储 并行计算
JAX 中文文档(七)(3)
JAX 中文文档(七)
10 0
|
9天前
|
机器学习/深度学习 异构计算 Python
JAX 中文文档(四)(3)
JAX 中文文档(四)
10 0
|
9天前
|
存储 Python
JAX 中文文档(十)(3)
JAX 中文文档(十)
7 0
|
9天前
|
Python
JAX 中文文档(十)(5)
JAX 中文文档(十)
11 0