形状多态性
当使用 JIT 模式的 JAX 时,函数将被跟踪、降级到 StableHLO,并针对每种输入类型和形状组合进行编译。在导出函数并在另一个系统上反序列化后,我们就无法再使用 Python 源代码,因此无法重新跟踪和重新降级它。形状多态性是 JAX 导出的一个特性,允许一些导出函数用于整个输入形状家族。这些函数在导出时只被跟踪和降级一次,并且Exported
对象包含编译和执行该函数所需的信息,可以在许多具体输入形状上进行编译和执行。我们通过在导出时指定包含维度变量(符号形状)的形状来实现这一点,例如下面的示例:
>>> import jax >>> from jax import export >>> from jax import numpy as jnp >>> def f(x): # f: f32[a, b] ... return jnp.concatenate([x, x], axis=1) >>> # We construct symbolic dimension variables. >>> a, b = export.symbolic_shape("a, b") >>> # We can use the symbolic dimensions to construct shapes. >>> x_shape = (a, b) >>> x_shape (a, b) >>> # Then we export with symbolic shapes: >>> exp: export.Exported = export.export(jax.jit(f))( ... jax.ShapeDtypeStruct(x_shape, jnp.int32)) >>> exp.in_avals (ShapedArray(int32[a,b]),) >>> exp.out_avals (ShapedArray(int32[a,2*b]),) >>> # We can later call with concrete shapes (with a=3 and b=4), without re-tracing `f`. >>> res = exp.call(np.ones((3, 4), dtype=np.int32)) >>> res.shape (3, 8)
注意,此类函数仍会按需为每个具体输入形状重新编译。仅跟踪和降级是保存的。
在上面的示例中,jax.export.symbolic_shape()
用于解析符号形状的字符串表示,将其转换为可以用于构造形状的维度表达式对象(类型为 _DimExpr
)。维度表达式对象重载了大多数整数运算符,因此在大多数情况下可以像使用整数常量一样使用它们。详细信息请参阅使用维度变量进行计算。
另外,我们提供了jax.export.symbolic_args_specs()
,可用于根据多态形状规范构建jax.ShapeDtypeStruct
对象的 pytrees:
>>> def f1(x, y): # x: f32[a, 1], y : f32[a, 4] ... return x + y >>> # Assuming you have some actual args with concrete shapes >>> x = np.ones((3, 1), dtype=np.int32) >>> y = np.ones((3, 4), dtype=np.int32) >>> args_specs = export.symbolic_args_specs((x, y), "a, ...") >>> exp = export.export(jax.jit(f1))(* args_specs) >>> exp.in_avals (ShapedArray(int32[a,1]), ShapedArray(int32[a,4]))
注意多态形状规范中的 "a, ..."
如何包含占位符 ...
,以从参数 (x, y)
的具体形状中填充。占位符 ...
代表 0 个或多个维度,而占位符 _
代表一个维度。jax.export.symbolic_args_specs()
支持参数的 pytrees,用于填充 dtypes 和任何占位符。该函数将构造与传递给它的参数结构相匹配的参数规范 pytree (jax.ShapeDtypeStruct
)。在某些情况下,多个参数应用相同规范的前缀,如上例所示。请参阅如何将可选参数匹配到参数。
几个形状规范的示例:
("(b, _, _)", None)
可以用于具有两个参数的函数,第一个是具有应为符号的批处理前导维度的三维数组。基于实际参数专门化第一个参数的其他维度和第二个参数的形状。请注意,如果第一个参数是具有相同前导维度但可能具有不同尾部维度的多个三维数组的 pytree,则相同的规范也适用。第二个参数的值None
表示该参数不是符号化的。等效地,可以使用...
。("(batch, ...)", "(batch,)")
指定两个参数具有匹配的前导维度,第一个参数至少具有秩为 1,第二个具有秩为 1。
形状多态的正确性
我们希望信任导出的程序在编译和执行适用于任何具体形状时产生与原始 JAX 程序相同的结果。更确切地说:
对于任何 JAX 函数f
和包含符号形状的参数规范arg_spec
,以及任何形状与arg_spec
匹配的具体参数arg
:
- 如果 JAX 本地执行在具体参数上成功:
res = f(arg)
, - 如果导出使用符号形状成功:
exp = export.export(f)(arg_spec)
, - 编译和运行导出程序将会成功并得到相同的结果:
res == exp.call(arg)
非常重要的是理解f(arg)
有自由重新调用 JAX 追踪机制,实际上对于每个不同的具体arg
形状都会这样做,而exp.call(arg)
的执行不能再使用 JAX 追踪(这种执行可能发生在无法访问f
源代码的环境中)。
确保这种正确性形式是困难的,在最困难的情况下,导出会失败。本章的其余部分描述了如何处理这些失败。
使用维度变量进行计算
JAX 跟踪所有中间结果的形状。当这些形状依赖于维度变量时,JAX 将它们计算为涉及维度变量的符号形状表达式。维度变量代表大于或等于 1 的整数值。这些符号表达式可以表示应用算术运算符(add、sub、mul、floordiv、mod,包括 NumPy 变体 np.sum
、np.prod
等)在维度表达式和整数上的结果(int
、np.int
,或者通过operator.index
可转换的任何内容)。这些符号维度随后可以在 JAX 原语和 API 的形状参数中使用,例如在jnp.reshape
、jnp.arange
、切片索引等。
例如,在以下代码中展平二维数组时,计算x.shape[0] * x.shape[1]
将计算符号维度4 * b
作为新形状:
>>> f = lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],)) >>> arg_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.int32) >>> exp = export.export(jax.jit(f))(arg_spec) >>> exp.out_avals (ShapedArray(int32[4*b]),)
可以将维度表达式明确转换为 JAX 数组,例如jnp.array(x.shape[0])
甚至jnp.array(x.shape)
。这些操作的结果可以用作常规的 JAX 数组,但不能再作为形状中的维度使用。
>>> exp = export.export(jax.jit(lambda x: jnp.array(x.shape[0]) + x))( ... jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32)) >>> exp.call(jnp.arange(3, dtype=np.int32)) Array([3, 4, 5], dtype=int32) >>> exp = export.export(jax.jit(lambda x: x.reshape(jnp.array(x.shape[0]) + 2)))( ... jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32)) Traceback (most recent call last): TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].
当符号维度与非整数(如 float
、np.float
、np.ndarray
或 JAX 数组)进行算术运算时,它会自动转换为 JAX 数组,使用 jnp.array
。例如,在下面的函数中,x.shape[0]
的所有出现都会被隐式转换为 jnp.array(x.shape[0])
,因为它们与非整数标量或 JAX 数组参与了运算:
>>> exp = export.export(jax.jit( ... lambda x: (5. + x.shape[0], ... x.shape[0] - np.arange(5, dtype=jnp.int32), ... x + x.shape[0] + jnp.sin(x.shape[0]))))( ... jax.ShapeDtypeStruct(export.symbolic_shape("b"), jnp.int32)) >>> exp.out_avals (ShapedArray(float32[], weak_type=True), ShapedArray(int32[5]), ShapedArray(float32[b], weak_type=True)) >>> exp.call(jnp.ones((3,), jnp.int32)) (Array(8., dtype=float32, weak_type=True), Array([ 3, 2, 1, 0, -1], dtype=int32), Array([4.14112, 4.14112, 4.14112], dtype=float32, weak_type=True))
另一个典型的例子是计算平均值(注意 x.shape[0]
如何自动转换为 JAX 数组):
>>> exp = export.export(jax.jit( ... lambda x: jnp.sum(x, axis=0) / x.shape[0]))( ... jax.ShapeDtypeStruct(export.symbolic_shape("b, c"), jnp.int32)) >>> exp.call(jnp.arange(12, dtype=jnp.int32).reshape((3, 4))) Array([4., 5., 6., 7.], dtype=float32)
存在形状多态性的错误
大多数 JAX 代码假定 JAX 数组的形状是整数元组,但是使用形状多态性时,某些维度可能是符号表达式。这可能导致多种错误。例如,我们可以遇到通常的 JAX 形状检查错误:
>>> v, = export.symbolic_shape("v,") >>> export.export(jax.jit(lambda x, y: x + y))( ... jax.ShapeDtypeStruct((v,), dtype=np.int32), ... jax.ShapeDtypeStruct((4,), dtype=np.int32)) Traceback (most recent call last): TypeError: add got incompatible shapes for broadcasting: (v,), (4,). >>> export.export(jax.jit(lambda x: jnp.matmul(x, x)))( ... jax.ShapeDtypeStruct((v, 4), dtype=np.int32)) Traceback (most recent call last): TypeError: dot_general requires contracting dimensions to have the same shape, got (4,) and (v,).
我们可以通过指定参数的形状(v, v)
来修复上述矩阵乘法示例。
部分支持符号维度的比较
在 JAX 内部存在多个形状比较的相等性和不等式比较,例如用于形状检查或甚至用于为某些原语选择实现。比较支持如下:
- 支持等式,但有一个注意事项:如果两个符号维度在所有维度变量的赋值下都表示相同的值,则等式求值为
True
,例如对于b + b == 2*b
;否则等式求值为False
。关于此行为的重要后果,请参见下文讨论。 - 不相等总是等于等式的否定。
- 不等式部分支持,类似于部分等式。然而,在这种情况下,我们考虑维度变量只取严格正整数。例如,
b >= 1
、b >= 0
、2 * a + b >= 3
是True
,而b >= 2
、a >= b
、a - b >= 0
是不确定的并会导致异常。
在无法将比较操作解析为布尔值的情况下,我们会引发 InconclusiveDimensionOperation
。例如,
import jax >>> export.export(jax.jit(lambda x: 0 if x.shape[0] + 1 >= x.shape[1] else 1))( ... jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), dtype=np.int32)) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'a + 1' >= 'b' is inconclusive. This error arises for comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a boolean value for all values of the symbolic dimensions involved.
如果出现 InconclusiveDimensionOperation
,您可以尝试几种策略:
- 如果您的代码使用内置的
max
或min
,或者使用np.max
或np.min
,那么可以将它们替换为core.max_dim
和core.min_dim
,这样可以将不等式比较延迟到编译时,当形状已知时。 - 尝试使用
core.max_dim
和core.min_dim
重写条件语句,例如,代替d if d > 0 else 0
,您可以写成core.max_dim(d, 0)
。 - 尝试重写代码,减少对维度应为整数的依赖,并依赖于符号维度在大多数算术运算中作为整数的鸭子类型。例如,代替
int(d) + 5
写成d + 5
。 - 按照下面的说明指定符号约束。
用户指定的符号约束
默认情况下,JAX 假定所有维度变量的取值大于或等于 1,并试图从中推导出其他简单的不等式,例如:
a + 2 >= 3
,a * 2 >= 1
,a + b + c >= 3
,a // 4 >= 0
,a**2 >= 1
,等等。
如果将符号形状规范更改为维度大小的隐式约束,可以避免一些不等比较失败。例如,
- 你可以使用
2*b
作为维度来约束它为偶数且大于或等于 2。 - 你可以使用
b + 15
作为维度来约束它至少为 16。例如,如果没有+ 15
部分,以下代码会失败,因为 JAX 将希望验证切片大小至多不超过轴大小。
>>> _ = export.export(jax.jit(lambda x: x[0:16]))( ... jax.ShapeDtypeStruct(export.symbolic_shape("b + 15"), dtype=np.int32))
这些隐式符号约束用于决定比较,并且在编译时检查,如下所述。
你也可以指定显式符号约束:
>>> # Introduce dimension variable with constraints. >>> a, b = export.symbolic_shape("a, b", ... constraints=("a >= b", "b >= 16")) >>> _ = export.export(jax.jit(lambda x: x[:x.shape[1], :16]))( ... jax.ShapeDtypeStruct((a, b), dtype=np.int32))
约束与隐式约束一起形成一个连接。你可以指定 >=
、<=
和 ==
约束。目前,JAX 对符号约束的推理支持有限:
- 对于形式为变量大于或等于或小于或等于常数的约束,你可以得到最大的功效。例如,从
a >= 16
和b >= 8
的约束中,我们可以推断出a + 2*b >= 32
。 - 当约束涉及更复杂的表达式时,例如从
a >= b + 8
我们可以推断出a - b >= 8
,但不能推断出a >= 9
。我们可能会在未来在这个领域有所改进。 - 等式约束被视为归一化规则。例如,
floordiv(a, b) = c
通过将所有左侧的出现替换为右侧来工作。只能有左侧是因子乘积的等式约束,例如a * b
,或4 * a
,或floordiv(a, b)
。因此,左侧不能包含顶层的加法或减法。
符号约束还可以帮助绕过 JAX 推理机制中的限制。例如,在下面的代码中,JAX 将尝试证明切片大小 x.shape[0] % 3
,即符号表达式 mod(b, 3)
,小于或等于轴大小 b
。对于所有严格正值的 b
来说,这是真的,但这并不是 JAX 符号比较规则能够证明的。因此,以下代码会引发错误:
from jax import lax >>> b, = export.symbolic_shape("b") >>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3) >>> export.export(jax.jit(f))( ... jax.ShapeDtypeStruct((b,), dtype=np.int32)) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'b' >= 'mod(b, 3)' is inconclusive. This error arises for comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a boolean value for all values of the symbolic dimensions involved.
一种选择是将代码限制为仅在轴大小是 3
的倍数上运行(通过在形状中用 3*b
替换 b
)。然后,JAX 将能够将模运算 mod(3*b, 3)
简化为 0
。另一种选择是添加一个带有确切不确定不等式的符号约束,JAX 正试图证明:
>>> b, = export.symbolic_shape("b", ... constraints=["b >= mod(b, 3)"]) >>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3) >>> _ = export.export(jax.jit(f))( ... jax.ShapeDtypeStruct((b,), dtype=np.int32))
就像隐式约束一样,显式符号约束在编译时使用相同的机制进行检查,如下所述。
符号维度范围
符号约束存储在一个αn jax.export.SymbolicScope
对象中,它会隐式地为每次调用jax.export.symbolic_shapes()
创建。您必须小心,不要混合使用不同范围的符号表达式。例如,下面的代码将失败,因为a1
和a2
使用了不同的范围(由不同调用jax.export.symbolic_shape()
创建):
>>> a1, = export.symbolic_shape("a,") >>> a2, = export.symbolic_shape("a,", constraints=("a >= 8",)) >>> a1 + a2 Traceback (most recent call last): ValueError: Invalid mixing of symbolic scopes for linear combination. Expected scope 4776451856 created at <doctest shape_poly.md[31]>:1:6 (<module>) and found for 'a' (unknown) scope 4776979920 created at <doctest shape_poly.md[32]>:1:6 (<module>) with constraints: a >= 8
源自单次调用jax.export.symbolic_shape()
的符号表达式共享一个范围,并且可以在算术操作中混合使用。结果也将共享相同的范围。
您可以重复使用范围:
>>> a, = export.symbolic_shape("a,", constraints=("a >= 8",)) >>> b, = export.symbolic_shape("b,", scope=a.scope) # Reuse the scope of `a` >>> a + b # Allowed b + a
您也可以显式创建范围:
>>> my_scope = export.SymbolicScope() >>> c, = export.symbolic_shape("c", scope=my_scope) >>> d, = export.symbolic_shape("d", scope=my_scope) >>> c + d # Allowed d + c
JAX 跟踪使用部分以形状为键的缓存,并且如果它们使用不同的范围,则打印相同的符号形状将被视为不同的。
相等性比较的注意事项
相等比较返回False
,对于b + 1 == b
或b == 0
(在这种情况下,对于所有维度变量的值,维度肯定不同),但对于b == 1
和a == b
也是如此。这是不稳定的,我们应该引发core.InconclusiveDimensionOperation
,因为在某些估值下结果应该是True
,在其他估值下应该是False
。我们选择使相等性变得全面,从而允许不稳定性,因为否则在哈希碰撞存在时(哈希维度表达式或包含它们的对象时,如形状,core.AbstractValue
,core.Jaxpr
),我们可能会遇到虚假错误。除了哈希错误外,相等性的部分语义还会导致以下表达式的错误b == a or b == b
或b in [a, b]
,即使我们改变比较的顺序也能避免错误。
形式为if x.shape[0] != 1: raise NiceErrorMessage
的代码在处理相等性时也是合理的,但形式为if x.shape[0] != 1: return 1
的代码是不稳定的。
JAX 中文文档(五)(2)https://developer.aliyun.com/article/1559809