JAX 中文文档(五)(1)https://developer.aliyun.com/article/1559807
维度变量必须能够从输入形状中解决
目前,当调用导出对象时,通过数组参数的形状间接传递维度变量的值是唯一的方法。例如,可以在调用类型为f32[b]
的第一个参数的形状中推断出b
的值。这对大多数用例都很有效,并且它反映了 JIT 函数的调用约定。
有时您可能希望导出一个由整数值参数化的函数,这些值确定程序中的某些形状。例如,我们可能希望导出下面定义的函数my_top_k
,其由值k
参数化,该值确定了结果的形状。下面的尝试将导致错误,因为维度变量k
不能从输入x: i32[4, 10]
的形状中推导出来:
>>> def my_top_k(k, x): # x: i32[4, 10], k <= 10 ... return lax.top_k(x, k)[0] # : i32[4, 3] >>> x = np.arange(40, dtype=np.int32).reshape((4, 10)) >>> # Export with static `k=3`. Since `k` appears in shapes it must be in `static_argnums`. >>> exp_static_k = export.export(jax.jit(my_top_k, static_argnums=0))(3, x) >>> exp_static_k.in_avals[0] ShapedArray(int32[4,10]) >>> exp_static_k.out_avals[0] ShapedArray(int32[4,3]) >>> # When calling the exported function we pass only the non-static arguments >>> exp_static_k.call(x) Array([[ 9, 8, 7], [19, 18, 17], [29, 28, 27], [39, 38, 37]], dtype=int32) >>> # Now attempt to export with symbolic `k` so that we choose `k` after export. >>> k, = export.symbolic_shape("k", constraints=["k <= 10"]) >>> export.export(jax.jit(my_top_k, static_argnums=0))(k, x) Traceback (most recent call last): KeyError: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments
未来,我们可能会添加额外的机制来传递维度变量的值,除了通过输入形状隐式传递外。与此同时,解决上述用例的方法是将函数参数k
替换为形状为(0, k)
的数组,这样k
可以从数组的输入形状中推导出来。第一个维度为 0 是为了确保整个数组为空,在调用导出函数时不会有性能惩罚。
>>> def my_top_k_with_dimensions(dimensions, x): # dimensions: i32[0, k], x: i32[4, 10] ... return my_top_k(dimensions.shape[1], x) >>> exp = export.export(jax.jit(my_top_k_with_dimensions))( ... jax.ShapeDtypeStruct((0, k), dtype=np.int32), ... x) >>> exp.in_avals (ShapedArray(int32[0,k]), ShapedArray(int32[4,10])) >>> exp.out_avals[0] ShapedArray(int32[4,k]) >>> # When we invoke `exp` we must construct and pass an array of shape (0, k) >>> exp.call(np.zeros((0, 3), dtype=np.int32), x) Array([[ 9, 8, 7], [19, 18, 17], [29, 28, 27], [39, 38, 37]], dtype=int32)
另一种可能出现错误的情况是一些维度变量出现在输入形状中,但以 JAX 目前无法解决的非线性表达式形式出现:
>>> a, = export.symbolic_shape("a") >>> export.export(jax.jit(lambda x: x.shape[0]))( ... jax.ShapeDtypeStruct((a * a,), dtype=np.int32)) Traceback (most recent call last): ValueError: Cannot solve for values of dimension variables {'a'}. We can only solve linear uni-variate constraints. Using the following polymorphic shapes specifications: args[0].shape = (a²,). Unprocessed specifications: 'a²' for dimension size args[0].shape[0].
形状断言错误
JAX 假设维度变量在严格正整数范围内,这一假设在为具体输入形状编译代码时被检查。
例如,对于符号输入形状(b, b, 2*d)
,当使用实际参数arg
调用时,JAX 将生成代码来检查以下断言:
arg.shape[0] >= 1
arg.shape[1] == arg.shape[0]
arg.shape[2] % 2 == 0
arg.shape[2] // 2 >= 1
例如,这是在对形状为(3, 3, 5)
的参数调用导出函数时得到的错误:
>>> def f(x): # x: f32[b, b, 2*d] ... return x >>> exp = export.export(jax.jit(f))( ... jax.ShapeDtypeStruct(export.symbolic_shape("b, b, 2*d"), dtype=np.int32)) >>> exp.call(np.ones((3, 3, 5), dtype=np.int32)) Traceback (most recent call last): ValueError: Input shapes do not match the polymorphic shapes specification. Division had remainder 1 when computing the value of 'd'. Using the following polymorphic shapes specifications: args[0].shape = (b, b, 2*d). Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3), . Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details.
这些错误出现在编译之前的预处理步骤中。
部分支持符号维度的除法
JAX 将尝试简化除法和取模运算,例如(a * b + a) // (b + 1) == a
和6*a + 4 % 3 == 1
。特别地,JAX 会处理以下情况:要么(a)没有余数,要么(b)除数是一个常数,此时可能有一个常数余数。
例如,尝试计算reshape
操作的推断维度时,以下代码会导致除法错误:
>>> b, = export.symbolic_shape("b") >>> export.export(jax.jit(lambda x: x.reshape((2, -1))))( ... jax.ShapeDtypeStruct((b,), dtype=np.int32)) Traceback (most recent call last): jax._src.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (b,) and (2, -1). The remainder mod(b, - 2) should be 0.
注意以下操作将成功:
>>> b, = export.symbolic_shape("b") >>> # We specify that the first dimension is a multiple of 4 >>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))( ... jax.ShapeDtypeStruct((4*b,), dtype=np.int32)) >>> exp.out_avals (ShapedArray(int32[2,2*b]),) >>> # We specify that some other dimension is even >>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))( ... jax.ShapeDtypeStruct((b, 5, 6), dtype=np.int32)) >>> exp.out_avals (ShapedArray(int32[2,15*b]),)
与 TensorFlow 的互操作
参见JAX2TF 文档。
JAX 错误
此页面列出了在使用 JAX 时可能遇到的一些错误,以及如何修复它们的代表性示例。
class jax.errors.ConcretizationTypeError(tracer, context='')
当 JAX 追踪器对象在需要具体值的上下文中使用时(参见关于 Tracer 是什么的更多信息),会发生此错误。在某些情况下,可以通过将问题值标记为静态来轻松修复;在其他情况下,可能表明您的程序正在执行 JAX JIT 编译模型不直接支持的操作。
例子:
在期望静态值的位置使用跟踪值
导致此错误的一个常见原因是在需要静态值的位置使用跟踪值。例如:
>>> from functools import partial >>> from jax import jit >>> import jax.numpy as jnp >>> @jit ... def func(x, axis): ... return x.min(axis)
>>> func(jnp.arange(4), 0) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: axis argument to jnp.min().
通常可以通过将问题参数标记为静态来解决此问题:
>>> @partial(jit, static_argnums=1) ... def func(x, axis): ... return x.min(axis) >>> func(jnp.arange(4), 0) Array(0, dtype=int32)
形状依赖于跟踪的值
在 JIT 编译的计算中,如果形状依赖于跟踪数量中的值时,也可能出现此类错误。例如:
>>> @jit ... def func(x): ... return jnp.where(x < 0) >>> func(jnp.arange(4)) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: The error arose in jnp.nonzero.
这是一个与 JAX JIT 编译模型不兼容的操作示例,该模型要求在编译时知道数组大小。这里返回的数组大小取决于 x 的内容,这样的代码不能 JIT 编译。
在许多情况下,可以通过修改函数中使用的逻辑来解决此问题;例如,这里是一个类似问题的代码:
>>> @jit ... def func(x): ... indices = jnp.where(x > 1) ... return x[indices].sum() >>> func(jnp.arange(4)) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: The error arose in jnp.nonzero.
以下是如何以避免创建动态大小索引数组的方式表达相同操作的示例:
>>> @jit ... def func(x): ... return jnp.where(x > 1, x, 0).sum() >>> func(jnp.arange(4)) Array(5, dtype=int32)
要了解与跟踪器与常规值,具体与抽象值相关的更多细微差别,可以阅读有关不同类型的 JAX 值的内容。
参数:
- 追踪器 (core.Tracer)
- 上下文 (str)
class jax.errors.KeyReuseError(message)
当 PRNG 密钥以不安全的方式重复使用时,会发生此错误。仅在设置 jax_debug_key_reuse
为 True 时检查密钥重复使用。
以下是导致此类错误的代码简单示例:
>>> with jax.debug_key_reuse(True): ... key = jax.random.key(0) ... value = jax.random.uniform(key) ... new_value = jax.random.uniform(key) ... --------------------------------------------------------------------------- KeyReuseError Traceback (most recent call last) ... KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
此类密钥重用存在问题,因为 JAX PRNG 是无状态的,必须手动分割密钥;有关更多信息,请参见 Sharp Bits: Random Numbers。
参数:
消息 (str)
class jax.errors.NonConcreteBooleanIndexError(tracer)
当程序尝试在跟踪索引操作中使用非具体布尔索引时,会发生此错误。在 JIT 编译下,JAX 数组必须具有静态形状(即在编译时已知的形状),因此布尔掩码必须小心使用。某些逻辑通过布尔掩码实现可能在 jax.jit()
函数中根本不可能;在其他情况下,可以使用 where()
的三参数版本以 JIT 兼容的方式重新表达逻辑。
以下是可能导致此错误的几个示例。
通过布尔掩码构建数组
在尝试在 JIT 上下文中通过布尔遮罩创建数组时最常见出现此错误。例如:
>>> import jax >>> import jax.numpy as jnp >>> @jax.jit ... def positive_values(x): ... return x[x > 0] >>> positive_values(jnp.arange(-5, 5)) Traceback (most recent call last): ... NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
此函数试图仅返回输入数组中的正值;除非将 x 标记为静态,否则在编译时无法确定返回数组的大小,因此无法在 JIT 编译下执行此类操作。
可重新表达的布尔逻辑
尽管不直接支持创建动态大小的数组,但在许多情况下可以重新表达计算逻辑以符合 JIT 兼容的操作。例如,以下是另一个因相同原因在 JIT 下失败的函数:
>>> @jax.jit ... def sum_of_positive(x): ... return x[x > 0].sum() >>> sum_of_positive(jnp.arange(-5, 5)) Traceback (most recent call last): ... NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
然而,在这种情况下,有问题的数组仅是一个中间值,我们可以使用支持 JIT 的三参数版本的 jax.numpy.where()
表达相同的逻辑:
>>> @jax.jit ... def sum_of_positive(x): ... return jnp.where(x > 0, x, 0).sum() >>> sum_of_positive(jnp.arange(-5, 5)) Array(10, dtype=int32)
将布尔遮罩替换为带有三个参数的 where()
的模式是解决这类问题的常见方法。
对 JAX 数组进行布尔索引
另一个经常出现此错误的情况是使用布尔索引,例如 .at[...].set(...)
。以下是一个简单的示例:
>>> @jax.jit ... def manual_clip(x): ... return x.at[x < 0].set(0) >>> manual_clip(jnp.arange(-2, 2)) Traceback (most recent call last): ... NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])
此函数试图将小于零的值设置为标量填充值。与上述类似,可以通过在 where()
中重新表达逻辑来解决此问题:
>>> @jax.jit ... def manual_clip(x): ... return jnp.where(x < 0, 0, x) >>> manual_clip(jnp.arange(-2, 2)) Array([0, 0, 0, 1], dtype=int32)
参数:
tracer (core.Tracer)
class jax.errors.TracerArrayConversionError(tracer)
当程序尝试将 JAX 追踪对象转换为标准的 NumPy 数组时会发生此错误(详见不同类型的 JAX 值,了解追踪器的更多信息)。通常情况下会发生在几种情况之一。
在 JAX 变换中使用非 JAX 函数
如果尝试在 JAX 变换(jit()
、grad()
、jax.vmap()
等)内部使用非 JAX 库如 numpy
或 scipy
,则可能会导致此错误。例如:
>>> from jax import jit >>> import numpy as np >>> @jit ... def func(x): ... return np.sin(x) >>> func(np.arange(4)) Traceback (most recent call last): ... TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[4]
在这种情况下,你可以通过使用 jax.numpy.sin()
替换 numpy.sin()
来解决问题:
>>> import jax.numpy as jnp >>> @jit ... def func(x): ... return jnp.sin(x) >>> func(jnp.arange(4)) Array([0\. , 0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
另请参阅 External Callbacks 了解从转换的 JAX 代码返回到主机端计算的选项。
使用追踪器索引 numpy 数组
如果此错误出现在涉及数组索引的行上,则可能是被索引的数组 x
是标准的 numpy.ndarray,而索引 idx
是追踪的 JAX 数组。例如:
>>> x = np.arange(10) >>> @jit ... def func(i): ... return x[i] >>> func(0) Traceback (most recent call last): ... TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[0]
根据上下文,你可以通过将 numpy 数组转换为 JAX 数组来解决此问题:
>>> @jit ... def func(i): ... return jnp.asarray(x)[i] >>> func(0) Array(0, dtype=int32)
或者通过将索引声明为静态参数:
>>> from functools import partial >>> @partial(jit, static_argnums=(0,)) ... def func(i): ... return x[i] >>> func(0) Array(0, dtype=int32)
JAX 中文文档(五)(3)https://developer.aliyun.com/article/1559810