JAX 中文文档(三)(3)https://developer.aliyun.com/article/1559704
checkify
转换
原文:
jax.readthedocs.io/en/latest/debugging/checkify_guide.html
TL;DR checkify
允许您向您的 JAX 代码添加可jit
的运行时错误检查(例如越界索引)。使用checkify.checkify
转换与类似断言的checkify.check
函数一起向 JAX 代码添加运行时检查:
from jax.experimental import checkify import jax import jax.numpy as jnp def f(x, i): checkify.check(i >= 0, "index needs to be non-negative, got {i}", i=i) y = x[i] z = jnp.sin(y) return z jittable_f = checkify.checkify(f) err, z = jax.jit(jittable_f)(jnp.ones((5,)), -2) print(err.get()) # >> index needs to be non-negative, got -2! (check failed at <...>:6 (f))
您还可以使用checkify
来自动添加常见的检查:
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks checked_f = checkify.checkify(f, errors=errors) err, z = checked_f(jnp.ones((5,)), 100) err.throw() # ValueError: out-of-bounds indexing at <..>:7 (f) err, z = checked_f(jnp.ones((5,)), -1) err.throw() # ValueError: index needs to be non-negative! (check failed at <…>:6 (f)) err, z = checked_f(jnp.array([jnp.inf, 1]), 0) err.throw() # ValueError: nan generated by primitive sin at <...>:8 (f) err, z = checked_f(jnp.array([5, 1]), 0) err.throw() # if no error occurred, throw does nothing!
功能化检查
与 assert 类似的检查 API 本身不是函数纯粹的:它可以作为副作用引发 Python 异常,就像 assert 一样。因此,它不能与jit
、pmap
、pjit
或scan
分阶段执行:
jax.jit(f)(jnp.ones((5,)), -1) # checkify transformation not used # ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized.
但是checkify
转换功能化(或卸载)这些效果。一个经过checkify
转换的函数将错误值作为新输出返回,并保持函数纯粹。这种功能化意味着checkify
转换的函数可以与我们喜欢的任何分阶段/转换进行组合:
err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100])) err.throw() """ ValueError: .. at mapped index 0: index needs to be non-negative! (check failed at :6 (f)) .. at mapped index 2: out-of-bounds indexing at <..>:7 (f) """
JAX 为什么需要checkify
?
在某些 JAX 转换下,您可以使用普通的 Python 断言表达运行时错误检查,例如仅使用jax.grad
和jax.numpy
时。
def f(x): assert x > 0., "must be positive!" return jnp.log(x) jax.grad(f)(0.) # ValueError: "must be positive!"
但是普通的断言在jit
、pmap
、pjit
或scan
中不起作用。在这些情况下,数值计算是在 Python 执行期间被分阶段地进行评估,因此数值值不可用:
jax.jit(f)(0.) # ConcretizationTypeError: "Abstract tracer value encountered ..."
在组合多个转换时,JAX 转换语义依赖于函数纯度,因此我们如何在不干扰所有这些的情况下提供一个错误机制?除了需要一个新的 API 之外,情况还更加棘手:XLA HLO 不支持断言或抛出错误,因此即使我们有一个能够分阶段断言的 JAX API,我们如何将这些断言降低到 XLA 呢?
您可以想象手动向函数添加运行时检查并通过值来传递表示错误:
def f_checked(x): error = x <= 0. result = jnp.log(x) return error, result err, y = jax.jit(f_checked)(0.) if err: raise ValueError("must be positive!") # ValueError: "must be positive!"
错误是由函数计算出的常规值,并且错误是在f_checked
外部引发的。f_checked
是函数式纯粹的,因此我们知道通过构造,它已经可以与jit
、pmap
、pjit
、scan
以及所有 JAX 的转换一起工作。唯一的问题是这些管道可能会很麻烦!
checkify
为您完成了这个重写工作:包括通过函数传递错误值、将检查重写为布尔操作并将结果与跟踪的错误值合并,并将最终错误值作为检查函数的输出返回:
def f(x): checkify.check(x > 0., "{} must be positive!", x) # convenient but effectful API return jnp.log(x) f_checked = checkify(f) err, x = jax.jit(f_checked)(-1.) err.throw() # ValueError: -1\. must be positive! (check failed at <...>:2 (f))
我们称这个过程为功能化或者通过调用检查引入的效果。 (在上面的“手动”示例中,错误值只是一个布尔值。checkify
的错误值在概念上类似,但还跟踪错误消息并公开抛出和获取方法;参见jax.experimental.checkify
)。checkify.check
还允许您通过将其作为格式参数提供给错误消息来将运行时值添加到您的错误消息中。
您现在可以手动为您的代码添加运行时检查,但 checkify
也可以自动添加常见错误的检查!考虑这些错误情况:
jnp.arange(3)[5] # out of bounds jnp.sin(jnp.inf) # NaN generated jnp.ones((5,)) / jnp.arange(5) # division by zero
默认情况下,checkify
仅释放 checkify.check
,不会捕获类似上述的错误。但如果您要求,checkify
也会自动在您的代码中添加检查。
def f(x, i): y = x[i] # i could be out of bounds. z = jnp.sin(y) # z could become NaN return z errors = checkify.user_checks | checkify.index_checks | checkify.float_checks checked_f = checkify.checkify(f, errors=errors) err, z = checked_f(jnp.ones((5,)), 100) err.throw() # ValueError: out-of-bounds indexing at <..>:7 (f) err, z = checked_f(jnp.array([jnp.inf, 1]), 0) err.throw() # ValueError: nan generated by primitive sin at <...>:8 (f)
基于 Sets 的 API,用于选择要启用的自动检查。详见 jax.experimental.checkify
获取更多详情。
在 JAX 变换下的 checkify
。
如上例所示,checkified 函数可以愉快地进行 jitted 处理。以下是 checkify
与其他 JAX 变换的几个示例。请注意,checkified 函数在功能上是纯粹的,并且应与所有 JAX 变换轻松组合!
jit
您可以安全地向 checkified 函数添加 jax.jit
,或者 checkify
一个 jitted 函数,两者都可以正常工作。
def f(x, i): return x[i] checkify_of_jit = checkify.checkify(jax.jit(f)) jit_of_checkify = jax.jit(checkify.checkify(f)) err, _ = checkify_of_jit(jnp.ones((5,)), 100) err.get() # out-of-bounds indexing at <..>:2 (f) err, _ = jit_of_checkify(jnp.ones((5,)), 100) # out-of-bounds indexing at <..>:2 (f)
vmap
/pmap
您可以 vmap
和 pmap
checkified 函数(或 checkify
映射函数)。映射一个 checkified 函数将为您提供一个映射的错误,该错误可以包含映射维度的每个元素的不同错误。
def f(x, i): checkify.check(i >= 0, "index needs to be non-negative!") return x[i] checked_f = checkify.checkify(f, errors=checkify.all_checks) errs, out = jax.vmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100])) errs.throw() """ ValueError: at mapped index 0: index needs to be non-negative! (check failed at <...>:2 (f)) at mapped index 2: out-of-bounds indexing at <...>:3 (f) """
然而,checkify-of-vmap
将产生单个(未映射)的错误!
@jax.vmap def f(x, i): checkify.check(i >= 0, "index needs to be non-negative!") return x[i] checked_f = checkify.checkify(f, errors=checkify.all_checks) err, out = checked_f(jnp.ones((3, 5)), jnp.array([-1, 2, 100])) err.throw() # ValueError: index needs to be non-negative! (check failed at <...>:2 (f))
pjit
对于 checkified 函数的 pjit
可以正常工作,您只需为错误值输出的 out_axis_resources
指定额外的 None
。
def f(x): return x / x f = checkify.checkify(f, errors=checkify.float_checks) f = pjit( f, in_shardings=PartitionSpec('x', None), out_shardings=(None, PartitionSpec('x', None))) with jax.sharding.Mesh(mesh.devices, mesh.axis_names): err, data = f(input_data) err.throw() # ValueError: divided by zero at <...>:4 (f)
grad
如果您使用 checkify-of-grad
,还将对您的梯度计算进行检查:
def f(x): return x / (1 + jnp.sqrt(x)) grad_f = jax.grad(f) err, _ = checkify.checkify(grad_f, errors=checkify.nan_checks)(0.) print(err.get()) >> nan generated by primitive mul at <...>:3 (f)
请注意,f
中没有乘法,但在其梯度计算中有乘法(这就是生成 NaN 的地方!)。因此,请使用 checkify-of-grad
为前向和后向传递操作添加自动检查。
checkify.check
仅应用于函数的主值。如果您想在梯度值上使用 check
,请使用 custom_vjp
:
@jax.custom_vjp def assert_gradient_negative(x): return x def fwd(x): return assert_gradient_negative(x), None def bwd(_, grad): checkify.check(grad < 0, "gradient needs to be negative!") return (grad,) assert_gradient_negative.defvjp(fwd, bwd) jax.grad(assert_gradient_negative)(-1.) # ValueError: gradient needs to be negative!
jax.experimental.checkify
的优势和限制
优势
- 您可以在任何地方使用它(错误只是“值”,并在像其他值一样的转换下直观地表现)。
- 自动插装:您无需对代码进行本地修改。相反,
checkify
可以为其所有部分添加插装!
限制
- 添加大量运行时检查可能很昂贵(例如,对每个原语添加 NaN 检查将增加计算中的许多操作)。
- 需要将错误值从函数中线程化并手动抛出错误。如果未显式抛出错误,则可能会错过错误!
- 抛出一个错误值将在主机上实现该错误值,这意味着它是一个阻塞操作,这会打败 JAX 的异步先行运行。
JAX 中文文档(三)(5)https://developer.aliyun.com/article/1559706