理解 Jaxpr
更新日期:2020 年 5 月 3 日(提交标识为 f1a46fe)。
从概念上讲,可以将 JAX 转换看作是首先对要转换的 Python 函数进行追踪特化,使其转换为一个小型且行为良好的中间形式,然后使用特定于转换的解释规则进行解释。JAX 能够在一个如此小的软件包中融合如此多的功能,其中一个原因是它从一个熟悉且灵活的编程接口(Python + NumPy)开始,并使用实际的 Python 解释器来完成大部分繁重的工作,将计算的本质提炼为一个简单的静态类型表达语言,具有有限的高阶特性。那种语言就是 jaxpr 语言。
并非所有 Python 程序都可以以这种方式处理,但事实证明,许多科学计算和机器学习程序可以。
在我们继续之前,有必要指出,并非所有的 JAX 转换都像上述描述的那样直接生成一个 jaxpr;有些转换(如微分或批处理)会在追踪期间逐步应用转换。然而,如果想要理解 JAX 内部工作原理,或者利用 JAX 追踪的结果,理解 jaxpr 是很有用的。
一个 jaxpr 实例表示一个带有一个或多个类型化参数(输入变量)和一个或多个类型化结果的函数。结果仅依赖于输入变量;没有从封闭作用域中捕获的自由变量。输入和输出具有类型,在 JAX 中表示为抽象值。代码中有两种相关的 jaxpr 表示,jax.core.Jaxpr
和 jax.core.ClosedJaxpr
。jax.core.ClosedJaxpr
表示部分应用的 jax.core.Jaxpr
,当您使用 jax.make_jaxpr()
检查 jaxpr 时获得。它具有以下字段:
jaxpr
是一个jax.core.Jaxpr
,表示函数的实际计算内容(如下所述)。consts
是一个常量列表。
jax.core.ClosedJaxpr
最有趣的部分是实际的执行内容,使用以下语法打印为 jax.core.Jaxpr
:
Jaxpr ::= { lambda Var* ; Var+. let Eqn* in [Expr+] }
其中:
- jaxpr 的参数显示为用
;
分隔的两个变量列表。第一组变量是引入的用于表示已提升的常量的变量。这些称为constvars
,在jax.core.ClosedJaxpr
中,consts
字段保存相应的值。第二组变量称为invars
,对应于跟踪的 Python 函数的输入。 Eqn*
是一个方程列表,定义了中间变量,这些变量指代中间表达式。每个方程将一个或多个变量定义为在某些原子表达式上应用基元的结果。每个方程仅使用输入变量和由前面的方程定义的中间变量。Expr+
:是 jaxpr 的输出原子表达式(文字或变量)列表。
方程式打印如下:
Eqn ::= Var+ = Primitive [ Param* ] Expr+
其中:
Var+
是要定义为基元调用的输出的一个或多个中间变量(某些基元可以返回多个值)。Expr+
是一个或多个原子表达式,每个表达式可以是变量或字面常量。特殊变量unitvar
或字面unit
,打印为*
,表示在计算的其余部分中不需要的值已被省略。也就是说,单元只是占位符。Param*
是基元的零个或多个命名参数,打印在方括号中。每个参数显示为Name = Value
。
大多数 jaxpr 基元是一阶的(它们只接受一个或多个Expr
作为参数):
Primitive := add | sub | sin | mul | ...
jaxpr 基元在jax.lax
模块中有文档。
例如,下面是函数func1
生成的 jaxpr 示例
>>> from jax import make_jaxpr >>> import jax.numpy as jnp >>> def func1(first, second): ... temp = first + jnp.sin(second) * 3. ... return jnp.sum(temp) ... >>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8))) { lambda ; a:f32[8] b:f32[8]. let c:f32[8] = sin b d:f32[8] = mul c 3.0 e:f32[8] = add a d f:f32[] = reduce_sum[axes=(0,)] e in (f,) }
在这里没有 constvars,a
和b
是输入变量,它们分别对应于first
和second
函数参数。标量文字3.0
保持内联。reduce_sum
基元具有命名参数axes
,除了操作数e
。
请注意,即使执行调用 JAX 的程序构建了 jaxpr,Python 级别的控制流和 Python 级别的函数也会正常执行。这意味着仅因为 Python 程序包含函数和控制流,生成的 jaxpr 不一定包含控制流或高阶特性。
例如,当跟踪函数func3
时,JAX 将内联调用inner
和条件if second.shape[0] > 4
,并生成与之前相同的 jaxpr
>>> def func2(inner, first, second): ... temp = first + inner(second) * 3. ... return jnp.sum(temp) ... >>> def inner(second): ... if second.shape[0] > 4: ... return jnp.sin(second) ... else: ... assert False ... >>> def func3(first, second): ... return func2(inner, first, second) ... >>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8))) { lambda ; a:f32[8] b:f32[8]. let c:f32[8] = sin b d:f32[8] = mul c 3.0 e:f32[8] = add a d f:f32[] = reduce_sum[axes=(0,)] e in (f,) }
处理 PyTrees
在 jaxpr 中不存在元组类型;相反,基元接受多个输入并产生多个输出。处理具有结构化输入或输出的函数时,JAX 将对其进行扁平化处理,并在 jaxpr 中它们将显示为输入和输出的列表。有关更多详细信息,请参阅 PyTrees(Pytrees)的文档。
例如,以下代码产生与前面看到的相同的 jaxpr(具有两个输入变量,每个输入元组的一个)
>>> def func4(arg): # Arg is a pair ... temp = arg[0] + jnp.sin(arg[1]) * 3. ... return jnp.sum(temp) ... >>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))) { lambda ; a:f32[8] b:f32[8]. let c:f32[8] = sin b d:f32[8] = mul c 3.0 e:f32[8] = add a d f:f32[] = reduce_sum[axes=(0,)] e in (f,) }
常量变量
jaxprs 中的某些值是常量,即它们的值不依赖于 jaxpr 的参数。当这些值是标量时,它们直接在 jaxpr 方程中表示;非标量数组常量则提升到顶级 jaxpr,其中它们对应于常量变量(“constvars”)。这些 constvars 与其他 jaxpr 参数(“invars”)在书面上的约定中有所不同。
高阶基元
jaxpr 包括几个高阶基元。它们更复杂,因为它们包括子 jaxprs。
条件语句
JAX 可以跟踪普通的 Python 条件语句。要捕获动态执行的条件表达式,必须使用jax.lax.switch()
和jax.lax.cond()
构造函数,它们的签名如下:
lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B
这两个都将在内部绑定一个名为 cond
的原始。jaxprs 中的 cond
原始反映了 lax.switch()
更一般签名的更多细节:它接受一个整数,表示要执行的分支的索引(被夹在有效索引范围内)。
例如:
>>> from jax import lax >>> >>> def one_of_three(index, arg): ... return lax.switch(index, [lambda x: x + 1., ... lambda x: x - 2., ... lambda x: x + 3.], ... arg) ... >>> print(make_jaxpr(one_of_three)(1, 5.)) { lambda ; a:i32[] b:f32[]. let c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a d:i32[] = clamp 0 c 2 e:f32[] = cond[ branches=( { lambda ; f:f32[]. let g:f32[] = add f 1.0 in (g,) } { lambda ; h:f32[]. let i:f32[] = sub h 2.0 in (i,) } { lambda ; j:f32[]. let k:f32[] = add j 3.0 in (k,) } ) linear=(False,) ] d b in (e,) }
cond 原始有多个参数:
- branches 是对应于分支函数的 jaxprs。在这个例子中,这些函数分别使用一个输入变量
x
。- linear 是一个布尔值元组,由自动微分机制内部使用,用于编码在条件语句中线性使用的输入参数。
cond 原始的上述实例接受两个操作数。第一个(d
)是分支索引,然后 b
是要传递给 branches
中任何 jaxpr 的操作数(arg
)。
另一个例子,使用 lax.cond()
:
>>> from jax import lax >>> >>> def func7(arg): ... return lax.cond(arg >= 0., ... lambda xtrue: xtrue + 3., ... lambda xfalse: xfalse - 3., ... arg) ... >>> print(make_jaxpr(func7)(5.)) { lambda ; a:f32[]. let b:bool[] = ge a 0.0 c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b d:f32[] = cond[ branches=( { lambda ; e:f32[]. let f:f32[] = sub e 3.0 in (f,) } { lambda ; g:f32[]. let h:f32[] = add g 3.0 in (h,) } ) linear=(False,) ] c a in (d,) }
在这种情况下,布尔谓词被转换为整数索引(0 或 1),branches
是对应于假和真分支的 jaxprs,按顺序排列。同样,每个函数都使用一个输入变量,分别对应于 xfalse
和 xtrue
。
下面的示例展示了当分支函数的输入是一个元组时,以及假分支函数包含被作为常量 hoisted 的 jnp.ones(1)
的更复杂情况
>>> def func8(arg1, arg2): # arg2 is a pair ... return lax.cond(arg1 >= 0., ... lambda xtrue: xtrue[0], ... lambda xfalse: jnp.array([1]) + xfalse[1], ... arg2) ... >>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.))) { lambda a:i32[1]; b:f32[] c:f32[1] d:f32[]. let e:bool[] = ge b 0.0 f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e g:f32[1] = cond[ branches=( { lambda ; h:i32[1] i:f32[1] j:f32[]. let k:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] h l:f32[1] = add k j in (l,) } { lambda ; m_:i32[1] n:f32[1] o:f32[]. let in (n,) } ) linear=(False, False, False) ] f a c d in (g,) }
虽然
就像条件语句一样,Python 循环在追踪期间是内联的。如果要捕获动态执行的循环,必须使用多个特殊操作之一,jax.lax.while_loop()
(一个原始)和 jax.lax.fori_loop()
(一个生成 while_loop 原始的辅助程序):
lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C
在上述签名中,“C”代表循环“carry”值的类型。例如,这里是一个 fori 循环的示例
>>> import numpy as np >>> >>> def func10(arg, n): ... ones = jnp.ones(arg.shape) # A constant ... return lax.fori_loop(0, n, ... lambda i, carry: carry + ones * 3. + arg, ... arg + ones) ... >>> print(make_jaxpr(func10)(np.ones(16), 5)) { lambda ; a:f32[16] b:i32[]. let c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0 d:f32[16] = add a c _:i32[] _:i32[] e:f32[16] = while[ body_jaxpr={ lambda ; f:f32[16] g:f32[16] h:i32[] i:i32[] j:f32[16]. let k:i32[] = add h 1 l:f32[16] = mul f 3.0 m:f32[16] = add j l n:f32[16] = add m g in (k, i, n) } body_nconsts=2 cond_jaxpr={ lambda ; o:i32[] p:i32[] q:f32[16]. let r:bool[] = lt o p in (r,) } cond_nconsts=0 ] c a 0 b d in (e,) }
while 原始接受 5 个参数:c a 0 b d
,如下所示:
- 0 个常量用于
cond_jaxpr
(因为cond_nconsts
为 0)- 两个常量用于
body_jaxpr
(c
和a
)- 初始携带值的 3 个参数
Scan
JAX 支持数组元素的特殊形式循环(具有静态已知形状)。由于迭代次数固定,这种形式的循环易于反向可微分。这些循环是用 jax.lax.scan()
函数构造的:
lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])
这是以 Haskell 类型签名 的形式编写的:C
是扫描携带的类型,A
是输入数组的元素类型,B
是输出数组的元素类型。
对于下面的函数 func11
的示例考虑
>>> def func11(arr, extra): ... ones = jnp.ones(arr.shape) # A constant ... def body(carry, aelems): ... # carry: running dot-product of the two arrays ... # aelems: a pair with corresponding elements from the two arrays ... ae1, ae2 = aelems ... return (carry + ae1 * ae2 + extra, carry) ... return lax.scan(body, 0., (arr, ones)) ... >>> print(make_jaxpr(func11)(np.ones(16), 5.)) { lambda ; a:f32[16] b:f32[]. let c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0 d:f32[] e:f32[16] = scan[ _split_transpose=False jaxpr={ lambda ; f:f32[] g:f32[] h:f32[] i:f32[]. let j:f32[] = mul h i k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g l:f32[] = add k j m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f n:f32[] = add l m in (n, g) } length=16 linear=(False, False, False, False) num_carry=1 num_consts=1 reverse=False unroll=1 ] b 0.0 a c in (d, e) }
linear
参数描述了每个输入变量在主体中是否保证线性使用。一旦扫描进行线性化,将有更多参数线性使用。
scan 原始接受 4 个参数:b 0.0 a c
,其中:
- 其中一个是主体的自由变量
- 其中一个是携带的初始值
- 接下来的两个是扫描操作的数组。
XLA_call
call
原语来源于 JIT 编译,它封装了一个子 jaxpr
和指定计算应在哪个后端和设备上运行的参数。例如
>>> from jax import jit >>> >>> def func12(arg): ... @jit ... def inner(x): ... return x + arg * jnp.ones(1) # Include a constant in the inner function ... return arg + inner(arg - 2.) ... >>> print(make_jaxpr(func12)(1.)) { lambda ; a:f32[]. let b:f32[] = sub a 2.0 c:f32[1] = pjit[ name=inner jaxpr={ lambda ; d:f32[] e:f32[]. let f:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0 g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d h:f32[1] = mul g f i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e j:f32[1] = add i h in (j,) } ] a b k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a l:f32[1] = add k c in (l,) }
XLA_pmap
如果使用 jax.pmap()
变换,要映射的函数是使用 xla_pmap
原语捕获的。考虑这个例子
>>> from jax import pmap >>> >>> def func13(arr, extra): ... def inner(x): ... # use a free variable "extra" and a constant jnp.ones(1) ... return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows') ... return pmap(inner, axis_name='rows')(arr) ... >>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.)) { lambda ; a:f32[1,3] b:f32[]. let c:f32[1,3] = xla_pmap[ axis_name=rows axis_size=1 backend=None call_jaxpr={ lambda ; d:f32[] e:f32[3]. let f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d g:f32[3] = add e f h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0 i:f32[3] = add g h j:f32[3] = psum[axes=('rows',) axis_index_groups=None] e k:f32[3] = div i j in (k,) } devices=None donated_invars=(False, False) global_axis_size=1 in_axes=(None, 0) is_explicit_global_axis_size=False name=inner out_axes=(0,) ] b a in (c,) }
xla_pmap
原语指定了轴的名称(参数 axis_name
)和要映射为 call_jaxpr
参数的函数体。此参数的值是一个具有 2 个输入变量的 Jaxpr。
参数 in_axes
指定了应该映射哪些输入变量和哪些应该广播。在我们的例子中,extra
的值被广播,arr
的值被映射。
JAX 中文文档(四)(2)https://developer.aliyun.com/article/1559793