## Part 3: `jit`, simplified

While `jit` has a transformation-like API in that it accepts a Python callable as an argument, under the hood it's really a higher-order primitive rather than a transformation. A primitive is higher-order when it's parameterized by a function.
### On-the-fly ("final style") and staged ("initial style") processing

There are two options for how to handle higher-order primitives. Each requires a different approach to tracing and engenders different tradeoffs:

- **On-the-fly processing, where `bind` takes a Python callable as an argument.** We defer forming a jaxpr until as late as possible, namely until we're running the final interpreter at the bottom of the interpreter stack. That way we can swap a `JaxprTrace` in at the bottom of the interpreter stack and thus stage out rather than execute all primitive operations. With this approach, transformations in the stack get applied as we execute the Python callable as usual. This approach can be very tricky to implement, but it's as general as possible because it allows higher-order primitives not to raise the abstraction level of their arguments, and thus allows data-dependent Python control flow. We refer to this approach as using "final-style higher-order primitives", employing the discharge-at-tracing-time "final-style transformations" we've used so far.
- **Staged processing, where `bind` takes a jaxpr as an argument.** Before we call `bind`, in the primitive wrapper we can just use `make_jaxpr` to form a jaxpr up-front and be done with the Python callable entirely. In this case, `make_jaxpr` puts its `JaxprTrace` at the top of the interpreter stack, and no transformations lower in the stack, which might enter via closed-over Tracers, are applied to the Python callable as we trace it. (Transformations applied within the Python callable are applied as usual, being added to the stack above the `JaxprTrace`.) Instead, the transformations lower in the stack are later applied to the call primitive, and the call primitive's rules must then transform the jaxpr itself. Because we trace to a jaxpr up-front, this approach can't support data-dependent Python control flow, but it is more straightforward to implement. We refer to this kind of higher-order primitive as an "initial-style higher-order primitive", and say that its jaxpr-processing transformation rules are "initial-style transformation rules."
The latter approach fits for `jit` because we don't need to support data-dependent Python control flow in the user-provided Python callable: the whole point of `jit` is to stage computation out of Python to be executed by XLA. (In contrast, `custom_jvp` is a higher-order primitive in which we do want to support data-dependent Python control flow.)

Historically, we started using the terminology "initial style" and "final style" here after reading the typed tagless final interpreters paper, and jokingly referring to JAX as an implementation of "untyped tagful final interpreters." We don't claim to carry over (or understand) any deep meaning behind these terms; we loosely use "initial style" to mean "build an AST and then transform it", and we use "final style" to mean "transform as we trace." But it's just imprecise yet sticky jargon.
With the initial-style approach, here's the user-facing `jit` wrapper:
```py
def jit(f):
  def f_jitted(*args):
    avals_in = [raise_to_shaped(get_aval(x)) for x in args]
    jaxpr, consts, out_tree = make_jaxpr(f, *avals_in)
    outs = bind(xla_call_p, *consts, *args, jaxpr=jaxpr, num_consts=len(consts))
    return tree_unflatten(out_tree, outs)
  return f_jitted

xla_call_p = Primitive('xla_call')
```
As with any new primitive, we need to give it transformation rules, starting with its evaluation rule. When we evaluate an application of the `xla_call` primitive, we want to stage out the computation to XLA. That involves translating the jaxpr to an XLA HLO program, transferring the argument values to the XLA device, executing the XLA program, and transferring back the results. We'll cache the XLA HLO compilation so that for each `jit`ted function it only needs to be performed once per argument shape and dtype signature.
First, some utilities.
```py
class IDHashable:
  val: Any

  def __init__(self, val):
    self.val = val

  def __hash__(self) -> int:
    return id(self.val)

  def __eq__(self, other):
    return type(other) is IDHashable and id(self.val) == id(other.val)
```
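Why hash by identity? It lets us use otherwise-unhashable objects (like a jaxpr, or a list of constants) as `lru_cache` keys, at the cost of treating equal-valued but distinct objects as different keys. Here's a self-contained sketch of that tradeoff; `compile_for` and its spec argument are illustrative stand-ins, not part of the tutorial's code:

```python
from functools import lru_cache

# Self-contained copy of IDHashable: hashing/equality depend on object
# identity, not value, so unhashable objects can serve as cache keys.
class IDHashable:
    def __init__(self, val):
        self.val = val

    def __hash__(self) -> int:
        return id(self.val)

    def __eq__(self, other):
        return type(other) is IDHashable and id(self.val) == id(other.val)

calls = []

@lru_cache()
def compile_for(hashable_spec):
    # Stand-in for expensive compilation, keyed on the wrapped object's identity.
    calls.append(hashable_spec.val)
    return f"compiled:{hashable_spec.val}"

spec = [3, "float32"]   # lists are unhashable, so they can't be cache keys directly
same = IDHashable(spec)
print(compile_for(same))                        # compiles
print(compile_for(same))                        # cache hit: same identity
print(compile_for(IDHashable([3, "float32"])))  # equal value, new identity: recompiles
print(len(calls))                               # → 2
```

The identity-keyed miss on the third call is harmless here: re-tracing the same `jit`ted function object reuses the same jaxpr object, so the common path stays cached.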
Next, we'll define the evaluation rule for `xla_call`:
```py
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
xe = xc._xla
xops = xc._xla.ops

def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
  consts, args = args[:num_consts], args[num_consts:]
  hashable_consts = tuple(map(IDHashable, consts))
  execute = xla_callable(IDHashable(jaxpr), hashable_consts)
  return execute(*args)
impl_rules[xla_call_p] = xla_call_impl

@lru_cache()
def xla_callable(hashable_jaxpr: IDHashable,
                 hashable_consts: tuple[IDHashable, ...]):
  jaxpr: Jaxpr = hashable_jaxpr.val
  typecheck_jaxpr(jaxpr)
  consts = [x.val for x in hashable_consts]
  in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]
  c = xc.XlaBuilder('xla_call')
  xla_consts = _xla_consts(c, consts)
  xla_params = _xla_params(c, in_avals)
  outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params)
  out = xops.Tuple(c, outs)
  compiled = xb.get_backend(None).compile(
    xc._xla.mlir.xla_computation_to_mlir_module(c.build(out)))
  return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])

def _xla_consts(c: xe.XlaBuilder, consts: list[Any]) -> list[xe.XlaOp]:
  unique_consts = {id(cnst): cnst for cnst in consts}
  xla_consts = {
      id_: xops.ConstantLiteral(c, cnst) for id_, cnst in unique_consts.items()}
  return [xla_consts[id(cnst)] for cnst in consts]

def _xla_params(c: xe.XlaBuilder, avals_in: list[ShapedArray]) -> list[xe.XlaOp]:
  return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]

def _xla_shape(aval: ShapedArray) -> xe.Shape:
  return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)
```
The main action is in `xla_callable`, which compiles a jaxpr into an XLA HLO program using `jaxpr_subcomp`, then returns a callable which executes the compiled program:
```py
def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: list[xe.XlaOp]
                  ) -> list[xe.XlaOp]:
  env: dict[Var, xe.XlaOp] = {}

  def read(x: Atom) -> xe.XlaOp:
    return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val))

  def write(v: Var, val: xe.XlaOp) -> None:
    env[v] = val

  map(write, jaxpr.in_binders, args)
  for eqn in jaxpr.eqns:
    in_avals = [x.aval for x in eqn.inputs]
    in_vals = map(read, eqn.inputs)
    rule = xla_translations[eqn.primitive]
    out_vals = rule(c, in_avals, in_vals, **eqn.params)
    map(write, eqn.out_binders, out_vals)
  return map(read, jaxpr.outs)

def execute_compiled(compiled, out_avals, *args):
  input_bufs = [input_handlers[type(x)](x) for x in args]
  out_bufs = compiled.execute(input_bufs)
  return [handle_result(aval, buf) for aval, buf in zip(out_avals, out_bufs)]

default_input_handler = xb.get_backend(None).buffer_from_pyval
input_handlers = {ty: default_input_handler for ty in
                  [bool, int, float, np.ndarray, np.float64, np.float32]}

def handle_result(aval: ShapedArray, buf):
  del aval  # Unused for now
  return np.asarray(buf)

xla_translations = {}
```
Notice that `jaxpr_subcomp` has the structure of a simple interpreter. That's a common pattern: the way we process jaxprs is usually with an interpreter. And as with any interpreter, we need an interpretation rule for each primitive:
```py
def direct_translation(op, c, in_avals, in_vals):
  del c, in_avals
  return [op(*in_vals)]

xla_translations[add_p] = partial(direct_translation, xops.Add)
xla_translations[mul_p] = partial(direct_translation, xops.Mul)
xla_translations[neg_p] = partial(direct_translation, xops.Neg)
xla_translations[sin_p] = partial(direct_translation, xops.Sin)
xla_translations[cos_p] = partial(direct_translation, xops.Cos)
xla_translations[greater_p] = partial(direct_translation, xops.Gt)
xla_translations[less_p] = partial(direct_translation, xops.Lt)

def reduce_sum_translation(c, in_avals, in_vals, *, axis):
  (x_aval,), (x,) = in_avals, in_vals
  zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype))
  subc = xc.XlaBuilder('add')
  shape = _xla_shape(ShapedArray((), x_aval.dtype))
  xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))
  return [xops.Reduce(c, [x], [zero], subc.build(), axis)]
xla_translations[reduce_sum_p] = reduce_sum_translation

def broadcast_translation(c, in_avals, in_vals, *, shape, axes):
  x, = in_vals
  dims_complement = [i for i in range(len(shape)) if i not in axes]
  return [xops.BroadcastInDim(x, shape, dims_complement)]
xla_translations[broadcast_p] = broadcast_translation
```
With that, we can now use `jit` to stage out, compile, and execute programs with XLA!
```py
@jit
def f(x, y):
  print('tracing!')
  return sin(x) * cos(y)
```
```py
z = f(3., 4.)  # 'tracing!' prints the first time
print(z)
```
```
tracing!
-0.09224219304455371
```
```py
z = f(4., 5.)  # 'tracing!' doesn't print, compilation cache hit!
print(z)
```
-0.21467624978306993
```py
@jit
def f(x):
  return reduce_sum(x, axis=0)

print(f(np.array([1., 2., 3.])))
```
6.0
```py
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

def deriv(f):
  return lambda x: jvp(f, (x,), (1.,))[1]

print(    deriv(deriv(f))(3.))
print(jit(deriv(deriv(f)))(3.))
```
```
0.2822400161197344
0.2822400161197344
```
Instead of implementing `jit` to first trace to a jaxpr and then to lower the jaxpr to XLA HLO, it might appear that we could have skipped the jaxpr step and just lowered to HLO while tracing. That is, perhaps we could have instead implemented `jit` with a `Trace` and `Tracer` that appended to the XLA HLO graph incrementally on each primitive bind. That's correct for now, but it won't be possible once we introduce compiled SPMD computations, because there we must know the number of replicas needed before compiling the program.
We haven't yet defined any transformation rules for `xla_call_p` other than its evaluation rule. That is, we can't yet do `vmap`-of-`jit` or `jvp`-of-`jit` or even `jit`-of-`jit`. Instead `jit` has to be at the "top level." Let's fix that!
```py
def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
  del num_consts  # Unused
  new_jaxpr, new_consts = jvp_jaxpr(jaxpr)
  outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,
              num_consts=len(new_consts))
  n = len(outs) // 2
  primals_out, tangents_out = outs[:n], outs[n:]
  return primals_out, tangents_out
jvp_rules[xla_call_p] = xla_call_jvp_rule

@lru_cache()
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
  def jvp_traceable(*primals_and_tangents):
    n = len(primals_and_tangents) // 2
    primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:]
    return jvp(jaxpr_as_fun(jaxpr), primals, tangents)

  in_avals = [v.aval for v in jaxpr.in_binders]
  new_jaxpr, new_consts, _ = make_jaxpr(jvp_traceable, *in_avals, *in_avals)
  return new_jaxpr, new_consts
```
```py
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
  del num_consts  # Unused
  new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
  outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
              num_consts=len(new_consts))
  return outs, [0] * len(outs)
vmap_rules[xla_call_p] = xla_call_vmap_rule

@lru_cache()
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
               ) -> tuple[Jaxpr, list[Any]]:
  vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
  in_avals = [unmapped_aval(axis_size, d, v.aval)
              for v, d in zip(jaxpr.in_binders, bdims_in)]
  new_jaxpr, new_consts, _ = make_jaxpr(vmap_traceable, *in_avals)
  return new_jaxpr, new_consts

def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
                  ) -> ShapedArray:
  if batch_dim is not_mapped:
    return aval
  else:
    shape = list(aval.shape)
    shape.insert(batch_dim, axis_size)
    return ShapedArray(tuple(shape), aval.dtype)
```
```py
def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):
  del num_consts  # Unused
  jaxpr_type = typecheck_jaxpr(jaxpr)
  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
    raise TypeError
  return jaxpr_type.out_types
abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule

def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts):
  del num_consts  # Only used at top-level.
  # Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.
  subc = xc.XlaBuilder('inner xla_call')
  xla_params = _xla_params(subc, in_avals)
  outs = jaxpr_subcomp(subc, jaxpr, xla_params)
  subc = subc.build(xops.Tuple(subc, outs))
  return destructure_tuple(c, xops.Call(c, subc, in_vals))
xla_translations[xla_call_p] = xla_call_translation

def destructure_tuple(c, tup):
  num_elements = len(c.get_shape(tup).tuple_shapes())
  return [xops.GetTupleElement(tup, i) for i in range(num_elements)]
```
```py
@jit
def f(x):
  print('tracing!')
  y = sin(x) * 2.
  z = - y + x
  return z

x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
```
```
tracing!
2.7177599838802657
2.979984993200891
```
```py
y, ydot = jvp(f, (x,), (xdot,))  # 'tracing!' not printed
```
```py
ys = vmap(f, (0,))(np.arange(3.))
print(ys)
```
```
[ 0.         -0.68294197  0.18140515]
```
One piece missing is device memory persistence for arrays. That is, we've defined `handle_result` to transfer results back to CPU memory as NumPy arrays, but it's often preferable to avoid transferring results just to transfer them back for the next operation. We can do that by introducing an `Array` class, which can wrap XLA buffers and otherwise duck-type `numpy.ndarray`s:
```py
def handle_result(aval: ShapedArray, buf):  # noqa: F811
  return Array(aval, buf)

class Array:
  buf: Any
  aval: ShapedArray

  def __init__(self, aval, buf):
    self.aval = aval
    self.buf = buf

  dtype = property(lambda self: self.aval.dtype)
  shape = property(lambda self: self.aval.shape)
  ndim  = property(lambda self: self.aval.ndim)

  def __array__(self): return np.asarray(self.buf)
  def __repr__(self):  return repr(np.asarray(self.buf))
  def __str__(self):   return str(np.asarray(self.buf))

  _neg = staticmethod(neg)
  _add = staticmethod(add)
  _radd = staticmethod(add)
  _mul = staticmethod(mul)
  _rmul = staticmethod(mul)
  _gt = staticmethod(greater)
  _lt = staticmethod(less)
input_handlers[Array] = lambda x: x.buf

jax_types.add(Array)
```
```py
@jit
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
```
```
2.7177599838802657
2.979984993200891
```
```py
def pprint_xla_call(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
  lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
  params_without_jaxpr = {k:v for k, v in eqn.params.items() if k != 'jaxpr'}
  rhs = (pp(eqn.primitive.name) >> pp_params(params_without_jaxpr) >>
         pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
                     for x in eqn.inputs)))
  return vcat([lhs >> pp(' = ') >> rhs,
               pp_jaxpr(eqn.params['jaxpr']).indent(2)])
pp_rules[xla_call_p] = pprint_xla_call
```

## Part 4: `linearize` and `vjp` (and `grad`!)

The `linearize` and `vjp` autodiff functions are built on `jvp`, but involve jaxprs as well. That's because both involve staging out, or delaying, computation.

### `linearize`

In the case of `linearize`, we want to stage out the linear part of a `jvp` computation. That is, in terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature), if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`, then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, using `T a` to mean "the tangent type of `a`" and using the "lollipop" `-o` rather than the arrow `->` to indicate a *linear* function. We define the semantics of `linearize` in terms of `jvp` too:

```py
y, f_lin = linearize(f, x)
y_dot = f_lin(x_dot)
```

gives the same result for `(y, y_dot)` as

```py
y, y_dot = jvp(f, (x,), (x_dot,))
```
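To make the claimed equivalence concrete, here is a hand-rolled sketch for `f = sin` in plain Python (these are illustrative stand-ins, not the tutorial's `jvp`/`linearize` implementations):

```python
import math

# Hand-rolled jvp and linearize for f = sin.
def jvp_sin(x, x_dot):
    return math.sin(x), math.cos(x) * x_dot

def linearize_sin(x):
    y = math.sin(x)
    cos_x = math.cos(x)           # linearization work done once, up front
    f_lin = lambda t: cos_x * t   # the delayed linear part, T a -o T b
    return y, f_lin

x, x_dot = 3.0, 1.0
y1, y_dot1 = jvp_sin(x, x_dot)
y2, f_lin = linearize_sin(x)
y_dot2 = f_lin(x_dot)
print(y1 == y2 and y_dot1 == y_dot2)  # → True

# f_lin is linear: f_lin(a*t1 + t2) == a*f_lin(t1) + f_lin(t2)
print(math.isclose(f_lin(2*0.5 + 0.25), 2*f_lin(0.5) + f_lin(0.25)))  # → True
```

Calling `f_lin` repeatedly reuses `cos_x`, which is the whole point: the nonlinear work of linearization happens once.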
where applying `f_lin` does not redo any of the linearization work. We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr.
Tangentially, now that we have linear arrows `-o`, we can provide a slightly more informative type for `jvp`:

```
jvp : (a -> b) -> (UnrestrictedUse a, T a) -o (UnrestrictedUse b, T b)
```
Here we're writing `UnrestrictedUse` just to indicate that we have a special pair where the first element can be used in an unrestricted (nonlinear) way. In conjunction with the linear arrow, this notation is meant to express that the function `jvp f` uses its first input in a nonlinear way but its second input in a linear way, producing a corresponding nonlinear output (which can be used in a nonlinear way) paired with a linear output. This more refined type signature encodes the data dependencies in `jvp f`, which are useful for partial evaluation.
To build the `f_lin` jaxpr from a JVP, we need to perform partial evaluation: we evaluate all the primal values as we trace, but stage the tangent computations into a jaxpr. This is our second way to build jaxprs. But where `make_jaxpr` and its underlying `JaxprTrace`/`JaxprTracer` interpreters aim to stage out every primitive bind, this second approach stages out only those primitive binds with a data dependence on tangent inputs.
First, some utilities:
```py
def split_half(lst: list[Any]) -> tuple[list[Any], list[Any]]:
  assert not len(lst) % 2
  return split_list(lst, len(lst) // 2)

def merge_lists(which: list[bool], l1: list[Any], l2: list[Any]) -> list[Any]:
  l1, l2 = iter(l1), iter(l2)
  out = [next(l2) if b else next(l1) for b in which]
  assert next(l1, None) is next(l2, None) is None
  return out
```
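These helpers are small but load-bearing, so here's a standalone sketch of their behavior (with `split_half` re-expressed via plain slicing in place of the tutorial's `split_list`):

```python
# Behaviorally-equivalent sketches of the two helpers.
def split_half(lst):
    assert not len(lst) % 2
    return lst[:len(lst) // 2], lst[len(lst) // 2:]

def merge_lists(which, l1, l2):
    l1, l2 = iter(l1), iter(l2)
    out = [next(l2) if b else next(l1) for b in which]
    assert next(l1, None) is next(l2, None) is None
    return out

# split_half separates a packed [primals..., tangents...] list:
primals, tangents = split_half([1.0, 2.0, 0.1, 0.2])
print(primals, tangents)  # → [1.0, 2.0] [0.1, 0.2]

# merge_lists interleaves two lists back into output order:
# which[i] == False takes the next element of l1, True takes from l2.
print(merge_lists([False, True, False], ['k1', 'k2'], ['u1']))
# → ['k1', 'u1', 'k2']
```

`merge_lists` is exactly what partial evaluation needs later, to reassemble known and unknown outputs into their original order.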
Next, we'll write `linearize` by combining `jvp` together with a general partial evaluation transformation:
```py
def linearize_flat(f, *primals_in):
  pvals_in = ([PartialVal.known(x) for x in primals_in] +
              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
  def f_jvp(*primals_tangents_in):
    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
    return [*primals_out, *tangents_out]
  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)
  primal_pvals, _ = split_half(pvals_out)
  assert all(pval.is_known for pval in primal_pvals)
  primals_out = [pval.const for pval in primal_pvals]
  f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])
  return primals_out, f_lin

def linearize(f, *primals_in):
  primals_in_flat, in_tree = tree_flatten(primals_in)
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)

  def f_lin(*tangents_in):
    tangents_in_flat, in_tree2 = tree_flatten(tangents_in)
    if in_tree != in_tree2: raise TypeError
    tangents_out_flat = f_lin_flat(*tangents_in_flat)
    return tree_unflatten(out_tree(), tangents_out_flat)

  return primals_out, f_lin

def vspace(aval: ShapedArray) -> ShapedArray:
  return raise_to_shaped(aval)  # TODO handle integers?
```
Now we turn to the general partial evaluation transformation. The goal is to accept a Python callable and a list of inputs, some known and some unknown, and to produce (1) all the outputs which can be computed from the known inputs, together with (2) a jaxpr representing the part of the Python callable's computation which can only be performed after the remaining inputs are known.
This transformation is tricky to summarize in a type signature. If we assume the input function's type signature is `(a1, a2) -> (b1, b2)`, where `a1` and `a2` represent the known and unknown inputs, respectively, and where `b1` only has a data dependency on `a1` while `b2` has some data dependency on `a2`, then we might write:

```
partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2)
```
In words, given values for the inputs of type `a1`, `partial_eval` produces the outputs of type `b1` along with "residual" values of existentially-quantified type `r` representing the intermediates required to complete the computation in the second stage. It also produces a function of type `(r, a2) -> b2` which accepts the residual values as well as the remaining inputs and produces the remaining outputs.
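A toy, hand-written instance of this signature may help. Assume (hypothetically) a function `f(x, y) = (neg(sin(x)), sin(x) * y)` with `x` known and `y` unknown; then `sin(x)` is both computable now and a residual needed later:

```python
import math

# Hand-written partial evaluation of f(x, y) = (-sin(x), sin(x) * y)
# with x known (a1) and y unknown (a2). Illustrative only.
def partial_eval_f(x):
    sin_x = math.sin(x)   # computable now: depends only on x
    b1 = -sin_x           # the known output
    r = sin_x             # residual carried to the second stage
    def stage2(r, y):     # the (r, a2) -> b2 function
        return r * y
    return b1, r, stage2

b1, r, stage2 = partial_eval_f(3.0)
b2 = stage2(r, 2.0)       # once y becomes known
# Same as evaluating f directly:
print(b1 == -math.sin(3.0) and b2 == math.sin(3.0) * 2.0)  # → True
```

Note that `sin(x)` is computed once and reused by the second stage, rather than being recomputed; that is exactly the role residuals play.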
We like to think of partial evaluation as "unzipping" one computation into two. For example, consider this jaxpr:
```
{ lambda a:float64[] .
  let b:float64[] = sin a
      c:float64[] = neg b
  in ( c ) }
```
A jaxpr for its JVP would look like:
```
{ lambda a:float64[] b:float64[] .
  let c:float64[] = sin a
      d:float64[] = cos a
      e:float64[] = mul d b
      f:float64[] = neg c
      g:float64[] = neg e
  in ( f, g ) }
```
If we imagine applying partial evaluation to this jaxpr with the first input known and the second unknown, we end up "unzipping" the JVP jaxpr into primal and tangent jaxprs:
```
{ lambda a:float64[] .
  let c:float64[] = sin a
      d:float64[] = cos a
      f:float64[] = neg c
  in ( f, d ) }
```
```
{ lambda d:float64[] b:float64[] .
  let e:float64[] = mul d b
      g:float64[] = neg e
  in ( g ) }
```
This second jaxpr represents the linear computation that we want from `linearize`.
However, unlike in this jaxpr example, we want the computation on known values to occur while evaluating the input Python callable. That is, rather than forming a jaxpr for the entire function `(a1, a2) -> (b1, b2)`, staging all operations out of Python first before sorting out what can be evaluated now and what must be delayed, we want only to form a jaxpr for those operations that *must* be delayed due to a dependence on unknown inputs. In the context of automatic differentiation, this is the feature that ultimately enables us to handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control flow works because partial evaluation keeps the primal computation in Python. As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out what can be evaluated and what must be staged out into a jaxpr.
First, we start with a `PartialVal` class, which represents a value that can be either known or unknown:
```py
class PartialVal(NamedTuple):
  aval: ShapedArray
  const: Optional[Any]

  @classmethod
  def known(cls, val: Any):
    return PartialVal(get_aval(val), val)

  @classmethod
  def unknown(cls, aval: ShapedArray):
    return PartialVal(aval, None)

  is_known   = property(lambda self: self.const is not None)
  is_unknown = property(lambda self: self.const is None)
```
Partial evaluation will take a list of `PartialVal`s representing inputs, and return a list of `PartialVal` outputs along with a jaxpr representing the delayed computation:
```py
def partial_eval_flat(f: Callable, pvals_in: list[PartialVal]
                      ) -> tuple[Jaxpr, list[PartialVal], list[Any]]:
  with new_main(PartialEvalTrace) as main:
    trace = PartialEvalTrace(main)
    tracers_in = [trace.new_arg(pval) for pval in pvals_in]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    pvals_out = [t.pval for t in tracers_out]
    unk_tracers_in  = [t for t in tracers_in  if t.pval.is_unknown]
    unk_tracers_out = [t for t in tracers_out if t.pval.is_unknown]
    jaxpr, consts = tracers_to_jaxpr(unk_tracers_in, unk_tracers_out)
  return jaxpr, pvals_out, consts
```
Next we need to implement `PartialEvalTrace` and its `PartialEvalTracer`. This interpreter will build a jaxpr on the fly while tracking data dependencies. To do so, it builds a bipartite directed acyclic graph (DAG) between `PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe` nodes, representing formulas for how to compute some values from others. One kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s primitive application, but we also have recipe types for constants and lambda binders:
```py
from weakref import ref, ReferenceType

class LambdaBindingRecipe(NamedTuple):
  pass

class ConstRecipe(NamedTuple):
  val: Any

class JaxprEqnRecipe(NamedTuple):
  prim: Primitive
  tracers_in: list['PartialEvalTracer']
  params: dict[str, Any]
  avals_out: list[ShapedArray]
  tracer_refs_out: list['ReferenceType[PartialEvalTracer]']

JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
```
```py
class PartialEvalTracer(Tracer):
  pval: PartialVal
  recipe: Optional[JaxprRecipe]

  def __init__(self, trace, pval, recipe):
    self._trace = trace
    self.pval = pval
    self.recipe = recipe

  aval = property(lambda self: self.pval.aval)

  def full_lower(self):
    if self.pval.is_known:
      return full_lower(self.pval.const)
    return self
```
The `PartialEvalTrace` contains the logic for constructing the graph of `JaxprRecipe`s and `PartialEvalTracer`s. Each argument corresponds to a `LambdaBindingRecipe` leaf node, and each constant is a `ConstRecipe` leaf node holding a reference to the constant. All other tracers and recipes come from `process_primitive`, which forms tracers with `JaxprEqnRecipe`s representing primitive applications.
For most primitives, the `process_primitive` logic is straightforward: if all inputs are known then we can bind the primitive on the known values (evaluating it in Python) and avoid forming tracers corresponding to the output. If instead any input is unknown, then we stage out into a `JaxprEqnRecipe` representing the primitive application. To build the tracers representing unknown outputs, we need avals, which we get from the abstract eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and `JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using weakrefs.)
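The weakref trick can be seen in isolation. Below is a minimal sketch (with made-up `Node` objects standing in for tracers and recipes): the recipe-to-tracer edge is weak, so once the last strong reference to a tracer goes away, the pair is collectible despite the back-edge:

```python
import weakref

class Node:
    """Stand-in for a tracer or a recipe; just an attribute bag."""
    pass

out = Node()     # plays the role of an output PartialEvalTracer
recipe = Node()  # plays the role of a JaxprEqnRecipe

recipe.tracer_ref_out = weakref.ref(out)  # recipe -> tracer: weak edge
out.recipe = recipe                       # tracer -> recipe: strong edge

assert recipe.tracer_ref_out() is out     # alive while strongly referenced
del out                                   # drop the last strong reference
print(recipe.tracer_ref_out())            # → None on CPython (refcount hits zero)
```

With two strong edges instead, the pair would form a reference cycle and only the cycle collector could reclaim it; the weak edge makes reclamation immediate.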
That `process_primitive` logic applies to most primitives, but `xla_call_p` requires recursive treatment. So we special-case its rule in a `partial_eval_rules` dict.
```py
class PartialEvalTrace(Trace):
  def new_arg(self, pval: PartialVal) -> Any:
    return PartialEvalTracer(self, pval, LambdaBindingRecipe())

  def lift(self, val: Any) -> PartialEvalTracer:
    return PartialEvalTracer(self, PartialVal.known(val), None)
  pure = lift

  def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:
    if tracer.pval.is_unknown:
      return tracer
    else:
      pval = PartialVal.unknown(raise_to_shaped(tracer.aval))
      return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))

  def process_primitive(self, primitive, tracers, params):
    if all(t.pval.is_known for t in tracers):
      return bind(primitive, *map(full_lower, tracers), **params)
    rule = partial_eval_rules.get(primitive)
    if rule: return rule(self, tracers, **params)
    tracers_in = [self.instantiate_const(t) for t in tracers]
    avals_in = [t.aval for t in tracers_in]
    avals_out = abstract_eval_rules[primitive](*avals_in, **params)
    tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)
                   for aval in avals_out]
    eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,
                         map(ref, tracers_out))
    for t in tracers_out: t.recipe = eqn
    return tracers_out

partial_eval_rules = {}
```
Now that we can build graph representations of jaxprs with `PartialEvalTrace`, we need a mechanism to convert the graph representation to a standard jaxpr. The jaxpr corresponds to a topological sort of the graph.
```py
def tracers_to_jaxpr(tracers_in: list[PartialEvalTracer],
                     tracers_out: list[PartialEvalTracer]):
  tracer_to_var: dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))
                                   for t in tracers_in}
  constvar_to_val: dict[int, Any] = {}
  constid_to_var: dict[int, Var] = {}
  processed_eqns: set[int] = set()
  eqns: list[JaxprEqn] = []
  for t in toposort(tracers_out, tracer_parents):
    if isinstance(t.recipe, LambdaBindingRecipe):
      assert id(t) in set(map(id, tracers_in))
    elif isinstance(t.recipe, ConstRecipe):
      val = t.recipe.val
      var = constid_to_var.get(id(val))
      if var is None:
        aval = raise_to_shaped(get_aval(val))
        var = constid_to_var[id(val)] = Var(aval)
        constvar_to_val[var] = val
      tracer_to_var[id(t)] = var
    elif isinstance(t.recipe, JaxprEqnRecipe):
      if id(t.recipe) not in processed_eqns:
        eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))
        processed_eqns.add(id(t.recipe))
    else:
      raise TypeError(t.recipe)

  constvars, constvals = unzip2(constvar_to_val.items())
  in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
  out_vars = [tracer_to_var[id(t)] for t in tracers_out]
  jaxpr = Jaxpr(in_binders, eqns, out_vars)
  typecheck_jaxpr(jaxpr)
  return jaxpr, constvals

def recipe_to_eqn(tracer_to_var: dict[int, Var], recipe: JaxprEqnRecipe
                  ) -> JaxprEqn:
  inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]
  out_binders = [Var(aval) for aval in recipe.avals_out]
  for t_ref, var in zip(recipe.tracer_refs_out, out_binders):
    if t_ref() is not None:
      tracer_to_var[id(t_ref())] = var
  return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)

def tracer_parents(t: PartialEvalTracer) -> list[PartialEvalTracer]:
  return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []
```
```py
def toposort(out_nodes: list[Any], parents: Callable[[Any], list[Any]]):
  if not out_nodes: return []
  out_nodes = remove_duplicates(out_nodes)

  child_counts = {}
  stack = list(out_nodes)
  while stack:
    node = stack.pop()
    if id(node) in child_counts:
      child_counts[id(node)] += 1
    else:
      child_counts[id(node)] = 1
      stack.extend(parents(node))
  for node in out_nodes:
    child_counts[id(node)] -= 1

  sorted_nodes = []
  childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]
  while childless_nodes:
    node = childless_nodes.pop()
    sorted_nodes.append(node)
    for parent in parents(node):
      if child_counts[id(parent)] == 1:
        childless_nodes.append(parent)
      else:
        child_counts[id(parent)] -= 1

  sorted_nodes = sorted_nodes[::-1]
  check_toposort(sorted_nodes, parents)
  return sorted_nodes

def remove_duplicates(lst):
  seen = set()
  return [x for x in lst if id(x) not in seen and not seen.add(id(x))]

def check_toposort(nodes: list[Any], parents: Callable[[Any], list[Any]]):
  seen = set()
  for node in nodes:
    assert all(id(parent) in seen for parent in parents(node))
    seen.add(id(node))
```

Now we can linearize!

```py
y, sin_lin = linearize(sin, 3.)
print(y, sin(3.))
print(sin_lin(1.), cos(3.))
```
```
0.1411200080598672 0.1411200080598672
-0.9899924966004454 -0.9899924966004454
```
To handle `linearize`-of-`jit`, we still need to write a partial evaluation rule for `xla_call_p`. Other than tracer bookkeeping, the main task is to perform partial evaluation of a jaxpr, "unzipping" it into two jaxprs.
There are actually two rules to write: one for trace-time partial evaluation, which we'll call `xla_call_partial_eval`, and one for partial evaluation of jaxprs, which we'll call `xla_call_peval_eqn`.
```py
def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
  del num_consts  # Unused
  in_unknowns = [not t.pval.is_known for t in tracers]
  jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)
  known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
  known_vals = [t.pval.const for t in known_tracers]
  outs1_res = bind(xla_call_p, *known_vals, jaxpr=jaxpr1, num_consts=0)
  outs1, res = split_list(outs1_res, len(jaxpr1.outs) - num_res)
  res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
           for v in jaxpr2.outs]
  eqn = JaxprEqnRecipe(xla_call_p, res_tracers + unknown_tracers,
                       dict(jaxpr=jaxpr2, num_consts=0),
                       [v.aval for v in jaxpr2.outs], map(ref, outs2))
  for t in outs2: t.recipe = eqn
  return merge_lists(out_unknowns, outs1, outs2)
partial_eval_rules[xla_call_p] = xla_call_partial_eval

def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
                       instantiate: Optional[list[bool]] = None,
                       ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
  env: dict[Var, bool] = {}
  residuals: set[Var] = set()

  def read(x: Atom) -> bool:
    return type(x) is Var and env[x]

  def write(unk: bool, v: Var) -> None:
    env[v] = unk

  def new_res(x: Atom) -> Atom:
    if type(x) is Var: residuals.add(x)
    return x

  eqns1, eqns2 = [], []
  map(write, in_unknowns, jaxpr.in_binders)
  for eqn in jaxpr.eqns:
    unks_in = map(read, eqn.inputs)
    rule = partial_eval_jaxpr_rules.get(eqn.primitive)
    if rule:
      eqn1, eqn2, unks_out, res = rule(unks_in, eqn)
      eqns1.append(eqn1); eqns2.append(eqn2); residuals.update(res)
      map(write, unks_out, eqn.out_binders)
    elif any(unks_in):
      inputs = [v if unk else new_res(v) for unk, v in zip(unks_in, eqn.inputs)]
      eqns2.append(JaxprEqn(eqn.primitive, inputs, eqn.params, eqn.out_binders))
      map(partial(write, True), eqn.out_binders)
    else:
      eqns1.append(eqn)
      map(partial(write, False), eqn.out_binders)
  out_unknowns = map(read, jaxpr.outs)
  if instantiate is not None:
    for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):
      if inst and not uk: new_res(v)
    out_unknowns = map(op.or_, out_unknowns, instantiate)

  residuals, num_res = list(residuals), len(residuals)
  assert all(type(v) is Var for v in residuals), residuals

  ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
  outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)

  jaxpr1 = Jaxpr(ins1, eqns1, outs1 + residuals)
  jaxpr2 = Jaxpr(residuals + ins2, eqns2, outs2)
  typecheck_partial_eval_jaxpr(jaxpr, in_unknowns, out_unknowns, jaxpr1, jaxpr2)

  return jaxpr1, jaxpr2, out_unknowns, num_res

def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
  jaxprty = typecheck_jaxpr(jaxpr)    # (a1,  a2) -> (b1, b2 )
  jaxpr1ty = typecheck_jaxpr(jaxpr1)  #  a1       -> (b1, res)
  jaxpr2ty = typecheck_jaxpr(jaxpr2)  # (res, a2) -> b2

  a1, a2 = partition_list(unks_in, jaxprty.in_types)
  b1, b2 = partition_list(unks_out, jaxprty.out_types)
  b1_, res = split_list(jaxpr1ty.out_types, len(b1))
  res_, a2_ = split_list(jaxpr2ty.in_types, len(res))
  b2_ = jaxpr2ty.out_types

  if jaxpr1ty.in_types != a1: raise TypeError
  if jaxpr2ty.out_types != b2: raise TypeError
  if b1 != b1_: raise TypeError
  if res != res_: raise TypeError
  if a2 != a2_: raise TypeError
  if b2 != b2_: raise TypeError

partial_eval_jaxpr_rules = {}

def xla_call_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
                       ) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Var]]:
  jaxpr = eqn.params['jaxpr']
  jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
  ins1, ins2 = partition_list(unks_in, eqn.inputs)
  out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)
  residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]
  eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),
                  out_binders1 + residuals)
  eqn2 = JaxprEqn(xla_call_p, residuals + ins2,
                  dict(jaxpr=jaxpr2, num_consts=0), out_binders2)
  return eqn1, eqn2, unks_out, residuals
partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn
```
With that, we can compose `linearize` and `jit` however we like:
```py
@jit
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
```
2.7177599838802657 2.979984993200891
```py
@jit
def f(x):
  y = sin(x) * 2.
  z = g(x, y)
  return z

@jit
def g(x, y):
  return cos(x) + y

y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
```
-0.7077524804807109 -2.121105001260758
### `vjp` and `grad`

The `vjp` transformation works a lot like linearize. Its type signature is analogous:
```
linearize : (a -> b) -> a -> (b, T a -o T b)
vjp       : (a -> b) -> a -> (b, T b -o T a)
```
The only difference is that we transpose the linear part of the computation before returning it, so that it goes from type `T a -o T b` to type `T b -o T a`. That is, we'll implement `vjp` as, essentially:
```py
def vjp(f, x):
  y, f_lin = linearize(f, x)
  f_vjp = lambda y_bar: transpose(f_lin)(y_bar)
  return y, f_vjp
```
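For intuition about what "transposing a linear function" means, here is a hand-rolled NumPy sketch (not the tutorial's transpose machinery): for a linear map `f(x) = A @ x`, the transposed map sends a cotangent `y_bar` to `A.T @ y_bar`, and the two are related by the defining adjoint property:

```python
import numpy as np

A = np.array([[1., 2.], [3., 4.], [5., 6.]])  # a linear map R^2 -> R^3

f_lin = lambda x: A @ x                        # T a -o T b
f_lin_transposed = lambda y_bar: A.T @ y_bar   # T b -o T a

x = np.array([1., -1.])
y_bar = np.array([1., 0., 2.])

# Defining property of the transpose: <y_bar, f(x)> == <f^T(y_bar), x>
lhs = y_bar @ f_lin(x)
rhs = f_lin_transposed(y_bar) @ x
print(np.allclose(lhs, rhs))  # → True
```

The jaxpr transpose interpreter below computes `A.T @ y_bar` without ever materializing `A`: it walks the linear equations backwards, applying each primitive's transpose rule.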
Since we have the linear computation as a jaxpr, not just a Python callable, we can implement the transpose transformation as a jaxpr interpreter.
```py
def vjp_flat(f, *primals_in):
  pvals_in = ([PartialVal.known(x) for x in primals_in] +
              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
  primal_pvals_in, tangent_pvals_in = split_half(pvals_in)
  def f_jvp(*primals_tangents_in):
    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
    return [*primals_out, *tangents_out]
  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)  # linearize
  primal_pvals, _ = split_half(pvals_out)
  assert all(pval.is_known for pval in primal_pvals)
  primals_out = [pval.const for pval in primal_pvals]
  transpose_inputs = consts + [UndefPrimal(p.aval) for p in tangent_pvals_in]
  f_vjp = lambda *cts: eval_jaxpr_transposed(jaxpr, transpose_inputs, cts)
  return primals_out, f_vjp

def vjp(f, *primals_in):
  primals_in_flat, in_tree = tree_flatten(primals_in)
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, f_vjp_flat = vjp_flat(f, *primals_in_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)

  def f_vjp(*cotangents_out):
    cotangents_out_flat, _ = tree_flatten(cotangents_out)
    cotangents_in_flat = f_vjp_flat(*cotangents_out_flat)
    return tree_unflatten(in_tree, cotangents_in_flat)

  return primals_out, f_vjp

class UndefPrimal(NamedTuple):
  aval: ShapedArray

register_pytree_node(UndefPrimal,
                     lambda u: (u.aval, ()),
                     lambda aval, _: UndefPrimal(aval))
```
We use `UndefPrimal` instances to indicate which arguments we want to transpose with respect to. These arise because in general, being explicit about closed-over values, we want to transpose functions of type `a -> b -o c` to functions of type `a -> c -o b`. Even more generally, the inputs with respect to which the function is linear could be scattered through the argument list. So we indicate the linear positions using `UndefPrimal`. We register `UndefPrimal` as a pytree node because the pytree mechanism gives a handy way to prune these placeholders out of argument lists.
Next, we can write `eval_jaxpr_transposed`, along with transpose rules for all primitives which can be linear in at least one argument:
```py
# NB: the analogous function in JAX is called 'backward_pass'
def eval_jaxpr_transposed(jaxpr: Jaxpr, args: list[Any], cotangents: list[Any]
                          ) -> list[Any]:
  primal_env: dict[Var, Any] = {}
  ct_env: dict[Var, Any] = {}

  def read_primal(x: Atom) -> Any:
    return primal_env.get(x, UndefPrimal(x.aval)) if type(x) is Var else x.val

  def write_primal(v: Var, val: Any) -> None:
    if type(val) is not UndefPrimal:
      primal_env[v] = val

  def read_cotangent(v: Var) -> Any:
    return ct_env.pop(v, np.zeros(v.aval.shape, v.aval.dtype))

  def write_cotangent(x: Atom, val: Any):
    if type(x) is Var and val is not None:
      ct_env[x] = add(ct_env[x], val) if x in ct_env else val

  map(write_primal, jaxpr.in_binders, args)
  map(write_cotangent, jaxpr.outs, cotangents)
  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.inputs)
    cts_in = map(read_cotangent, eqn.out_binders)
    rule = transpose_rules[eqn.primitive]
    cts_out = rule(cts_in, *primals_in, **eqn.params)
    map(write_cotangent, eqn.inputs, cts_out)
  return [read_cotangent(v) for v, x in zip(jaxpr.in_binders, args)
          if type(x) is UndefPrimal]

transpose_rules = {}
```
```py
def mul_transpose_rule(cts, x, y):
  z_bar, = cts
  assert (type(x) is UndefPrimal) ^ (type(y) is UndefPrimal)
  return [mul(z_bar, y), None] if type(x) is UndefPrimal else [None, mul(x, z_bar)]
transpose_rules[mul_p] = mul_transpose_rule

def neg_transpose_rule(cts, x):
  ybar, = cts
  assert type(x) is UndefPrimal
  return [neg(ybar)]
transpose_rules[neg_p] = neg_transpose_rule

def add_transpose_rule(cts, x, y):
  z_bar, = cts
  return [z_bar, z_bar]
transpose_rules[add_p] = add_transpose_rule

def reduce_sum_transpose_rule(cts, x, *, axis):
  y_bar, = cts
  return [broadcast(y_bar, x.aval.shape, axis)]
transpose_rules[reduce_sum_p] = reduce_sum_transpose_rule

def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
  del num_consts  # Unused
  undef_primals = [type(x) is UndefPrimal for x in invals]
  transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))
  residuals, _ = partition_list(undef_primals, invals)
  outs = bind(xla_call_p, *new_consts, *residuals, *cts,
              jaxpr=transposed_jaxpr, num_consts=len(new_consts))
  outs = iter(outs)
  return [next(outs) if undef else None for undef in undef_primals]
transpose_rules[xla_call_p] = xla_call_transpose_rule

@lru_cache()
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
                    ) -> tuple[Jaxpr, list[Any]]:
  avals_in, avals_out = typecheck_jaxpr(jaxpr)
  traceable = partial(eval_jaxpr_transposed, jaxpr)
  args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)]
  trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out))
  typecheck_jaxpr(trans_jaxpr)
  return trans_jaxpr, consts
```
Now that we can linearize and transpose, we can finally write `grad`:
```py
def grad(f):
  def gradfun(x, *xs):
    y, f_vjp = vjp(f, x, *xs)
    if np.shape(y) != (): raise TypeError
    x_bar, *_ = f_vjp(np.ones(np.shape(y), np.result_type(y)))
    return x_bar
  return gradfun
```
```py
y, f_vjp = vjp(sin, 3.)
print(f_vjp(1.), cos(3.))
```
(np.float64(-0.9899924966004454),) -0.9899924966004454
```py
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

print(grad(f)(3.))
```
2.979984993200891
```py
@jit
def f(x):
  y = x * 2.
  z = g(y)
  return z

@jit
def g(x):
  return cos(x) * 2.

print(grad(f)(3.))
```
1.1176619927957034
Here's something of a compositionality stress test:
```py
# from core_test.py fun_with_nested_calls_2
def foo(x):
  @jit
  def bar(y):
    def baz(w):
      q = jit(lambda x: y)(x)
      q = q + jit(lambda: y)()
      q = q + jit(lambda y: w + y)(y)
      q = jit(lambda w: jit(sin)(x) * y)(1.0) + q
      return q
    p, t = jvp(baz, (x + 1.0,), (y,))
    return t + (x * p)
  return bar(x)

def assert_allclose(*vals):
  for v1, v2 in zip(vals[:-1], vals[1:]):
    np.testing.assert_allclose(v1, v2)

ans1 = f(3.)
ans2 = jit(f)(3.)
ans3, _ = jvp(f, (3.,), (5.,))
ans4, _ = jvp(jit(f), (3.,), (5.,))
assert_allclose(ans1, ans2, ans3, ans4)

deriv1 = grad(f)(3.)
deriv2 = grad(jit(f))(3.)
deriv3 = jit(grad(jit(f)))(3.)
_, deriv4 = jvp(f, (3.,), (1.,))
_, deriv5 = jvp(jit(f), (3.,), (1.,))
assert_allclose(deriv1, deriv2, deriv3, deriv4, deriv5)

hess1 = grad(grad(f))(3.)
hess2 = grad(grad(jit(f)))(3.)
hess3 = grad(jit(grad(f)))(3.)
hess4 = jit(grad(grad(f)))(3.)
_, hess5 = jvp(grad(f), (3.,), (1.,))
_, hess6 = jvp(jit(grad(f)), (3.,), (1.,))
_, hess7 = jvp(jit(grad(f)), (3.,), (1.,))
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)
```