JAX 中文文档(十)(4)

简介: JAX 中文文档(十)

JAX 中文文档(十)(3)https://developer.aliyun.com/article/1559709


第三部分:jit,简化

虽然jit具有类似于转换的 API,因为它接受 Python 可调用对象作为参数,但在幕后它实际上是一个高阶原语,而不是转换。当参数化为函数时,一个原语是高阶的。

即时(“final style”)和分阶段(“initial style”)处理

处理高阶原语有两个选择。每种选择都需要不同的跟踪方法,并产生不同的权衡:

  1. 即时处理,在bind将 Python 可调用对象作为参数。 我们推迟形成 jaxpr,直到可能的最后一刻,即在解释器栈底部运行最终解释器时。这样我们可以在解释器栈底部换上一个JaxprTrace,从而分阶段而不是执行所有原始操作。采用这种方法,堆栈中的转换会在我们像往常一样执行  Python 可调用对象时应用。这种方法实现起来可能非常棘手,但尽可能通用,因为它允许高阶原语不提升其参数的抽象级别,从而允许数据相关的  Python 控制流。我们称之为使用“最终风格高阶原语”,采用了迄今为止使用的“追踪时排除”最终风格变换。
  2. 分阶段处理,在bind将 jaxpr 作为参数。 在我们调用bind之前,在原始包装器中我们可以直接使用make_jaxpr来预先形成 jaxpr 并完全结束 Python 可调用对象。在这种情况下,make_jaxpr将其JaxprTrace放在解释器栈的顶部,并且没有低于堆栈的变换会通过闭合的  Tracer 输入到我们追踪的 Python 可调用对象中。 (在 Python 可调用对象内部应用的转换会像往常一样应用,被添加到  JaxprTrace 之上的堆栈中。)相反,堆栈中较低的转换稍后将应用于调用原始操作,并且调用原始操作的规则必须然后转换 jaxpr  本身。由于我们预先追踪到一个 jaxpr,这种方法不能支持数据相关的 Python  控制流,但它实现起来更为直接。我们将这种类型的高阶原语称为“初始风格高阶原语”,并说其 jaxpr 处理转换规则是“初始风格变换规则”。

后一种方法适用于jit,因为我们不需要支持用户提供的 Python 可调用对象中的数据相关 Python 控制流,因为jit的整个目的是将计算从 Python 阶段出来以供 XLA 执行。(相反,custom_jvp是一个高阶原语,我们希望在其中支持数据相关的 Python 控制流。)

在阅读了类型标签最终解释器论文后,我们从历史上开始使用“初始风格”和“最终风格”术语,并开玩笑称  JAX 是“未类型化的标签满足最终解释器”的实现。我们并不声称传承(或理解)这些术语背后的任何深层含义;我们宽泛地使用“初始风格”来表示“构建  AST 然后转换它”,并且我们使用“最终风格”来表示“追踪时转换”。但这只是不精确但易记的行话。

使用初始风格方法,这里是用户界面的jit包装器:

def jit(f):
  def f_jitted(*args):
    avals_in = [raise_to_shaped(get_aval(x)) for x in args]
    jaxpr, consts, out_tree = make_jaxpr(f, *avals_in)
    outs = bind(xla_call_p, *consts, *args, jaxpr=jaxpr, num_consts=len(consts))
    return tree_unflatten(out_tree, outs)
  return f_jitted
xla_call_p = Primitive('xla_call') 

对于任何新的原语,我们都需要为其提供转换规则,从其评估规则开始。当我们评估xla_call原语的应用时,我们希望将计算分阶段到 XLA。这涉及将 jaxpr 转换为 XLA HLO 程序,将参数值传输到 XLA 设备,执行 XLA 程序,并将结果传输回来。我们将缓存 XLA HLO 编译,以便于每个jit函数只需在参数形状和 dtype 签名上执行一次。

首先,一些实用工具。

class IDHashable:
  val: Any
  def __init__(self, val):
    self.val = val
  def __hash__(self) -> int:
    return id(self.val)
  def __eq__(self, other):
    return type(other) is IDHashable and id(self.val) == id(other.val) 

接下来,我们将为xla_call定义评估规则:

from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
xe = xc._xla
xops = xc._xla.ops
def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
  consts, args = args[:num_consts], args[num_consts:]
  hashable_consts = tuple(map(IDHashable, consts))
  execute = xla_callable(IDHashable(jaxpr), hashable_consts)
  return execute(*args)
impl_rules[xla_call_p] = xla_call_impl
@lru_cache()
def xla_callable(hashable_jaxpr: IDHashable,
                 hashable_consts: tuple[IDHashable, ...]):
  jaxpr: Jaxpr = hashable_jaxpr.val
  typecheck_jaxpr(jaxpr)
  consts = [x.val for x in hashable_consts]
  in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]
  c = xc.XlaBuilder('xla_call')
  xla_consts = _xla_consts(c, consts)
  xla_params = _xla_params(c, in_avals)
  outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params)
  out = xops.Tuple(c, outs)
  compiled = xb.get_backend(None).compile(
    xc._xla.mlir.xla_computation_to_mlir_module(c.build(out)))
  return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])
def _xla_consts(c: xe.XlaBuilder, consts: list[Any]) -> list[xe.XlaOp]:
  unique_consts = {id(cnst): cnst for cnst in consts}
  xla_consts = {
      id_: xops.ConstantLiteral(c, cnst) for id_, cnst in unique_consts.items()}
  return [xla_consts[id(cnst)] for cnst in consts]
def _xla_params(c: xe.XlaBuilder, avals_in: list[ShapedArray]) -> list[xe.XlaOp]:
  return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]
def _xla_shape(aval: ShapedArray) -> xe.Shape:
  return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape) 

主要操作在xla_callable中进行,它使用jaxpr_subcomp将 jaxpr 编译成 XLA HLO 程序,然后返回一个可调用对象来执行编译后的程序:

def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: list[xe.XlaOp]
                  ) -> list[xe.XlaOp]:
  env: dict[Var, xe.XlaOp] = {}
  def read(x: Atom) -> xe.XlaOp:
    return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val))
  def write(v: Var, val: xe.XlaOp) -> None:
    env[v] = val
  map(write, jaxpr.in_binders, args)
  for eqn in jaxpr.eqns:
    in_avals = [x.aval for x in eqn.inputs]
    in_vals = map(read, eqn.inputs)
    rule = xla_translations[eqn.primitive]
    out_vals = rule(c, in_avals, in_vals, **eqn.params)
    map(write, eqn.out_binders, out_vals)
  return map(read, jaxpr.outs)
def execute_compiled(compiled, out_avals, *args):
  input_bufs = input_handlers[type(x) for x in args]
  out_bufs = compiled.execute(input_bufs)
  return [handle_result(aval, buf) for aval, buf in zip(out_avals, out_bufs)]
default_input_handler = xb.get_backend(None).buffer_from_pyval
input_handlers = {ty: default_input_handler for ty in
                  [bool, int, float, np.ndarray, np.float64, np.float32]}
def handle_result(aval: ShapedArray, buf):
  del aval  # Unused for now
  return np.asarray(buf)
xla_translations = {} 

请注意,jaxpr_subcomp具有简单解释器的结构。这是一个常见模式:我们处理 jaxprs 的方式通常是使用解释器。与任何解释器一样,我们需要为每个原语定义一个解释规则:

def direct_translation(op, c, in_avals, in_vals):
  del c, in_avals
  return [op(*in_vals)]
xla_translations[add_p] = partial(direct_translation, xops.Add)
xla_translations[mul_p] = partial(direct_translation, xops.Mul)
xla_translations[neg_p] = partial(direct_translation, xops.Neg)
xla_translations[sin_p] = partial(direct_translation, xops.Sin)
xla_translations[cos_p] = partial(direct_translation, xops.Cos)
xla_translations[greater_p] = partial(direct_translation, xops.Gt)
xla_translations[less_p] = partial(direct_translation, xops.Lt)
def reduce_sum_translation(c, in_avals, in_vals, *, axis):
  (x_aval,), (x,) = in_avals, in_vals
  zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype))
  subc = xc.XlaBuilder('add')
  shape = _xla_shape(ShapedArray((), x_aval.dtype))
  xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))
  return [xops.Reduce(c, [x], [zero], subc.build(), axis)]
xla_translations[reduce_sum_p] = reduce_sum_translation
def broadcast_translation(c, in_avals, in_vals, *, shape, axes):
  x, = in_vals
  dims_complement = [i for i in range(len(shape)) if i not in axes]
  return [xops.BroadcastInDim(x, shape, dims_complement)]
xla_translations[broadcast_p] = broadcast_translation 

有了这个,我们现在可以使用jit来分阶段、编译和执行 XLA 程序了!

@jit
def f(x, y):
  print('tracing!')
  return sin(x) * cos(y) 
z = f(3., 4.)  # 'tracing!' prints the first time
print(z) 
tracing!
-0.09224219304455371 
z = f(4., 5.)  # 'tracing!' doesn't print, compilation cache hit!
print(z) 
-0.21467624978306993 
@jit
def f(x):
  return reduce_sum(x, axis=0)
print(f(np.array([1., 2., 3.]))) 
6.0 
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z
def deriv(f):
  return lambda x: jvp(f, (x,), (1.,))[1]
print(    deriv(deriv(f))(3.))
print(jit(deriv(deriv(f)))(3.)) 
0.2822400161197344
0.2822400161197344 

而不是实现jit以首先对 jaxpr 进行跟踪,然后将 jaxpr 降低到 XLA HLO,我们可能看起来可以跳过 jaxpr 步骤,而在跟踪时直接降低到 HLO。也就是说,也许我们可以用一个TraceTracer实现jit,在每个原语绑定时逐步追加到 XLA HLO 图中。目前这样做是正确的,但当我们引入编译的 SPMD 计算时,就不可能了,因为在编译程序之前我们必须知道所需的副本数量。

我们尚未为xla_call_p定义任何转换规则,除了其评估规则。也就是说,我们尚不能做vmap-of-jitjvp-of-jit甚至jit-of-jit。相反,jit必须处于“顶层”。让我们来修复这个问题!

def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
  del num_consts  # Unused
  new_jaxpr, new_consts = jvp_jaxpr(jaxpr)
  outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,
              num_consts=len(new_consts))
  n = len(outs) // 2
  primals_out, tangents_out = outs[:n], outs[n:]
  return primals_out, tangents_out
jvp_rules[xla_call_p] = xla_call_jvp_rule
@lru_cache()
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
  def jvp_traceable(*primals_and_tangents):
    n = len(primals_and_tangents) // 2
    primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:]
    return jvp(jaxpr_as_fun(jaxpr), primals, tangents)
  in_avals = [v.aval for v in jaxpr.in_binders]
  new_jaxpr, new_consts, _ = make_jaxpr(jvp_traceable, *in_avals, *in_avals)
  return new_jaxpr, new_consts 
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
  del num_consts  # Unused
  new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
  outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
              num_consts=len(new_consts))
  return outs, [0] * len(outs)
vmap_rules[xla_call_p] = xla_call_vmap_rule
@lru_cache()
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
               ) -> tuple[Jaxpr, list[Any]]:
  vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
  in_avals = [unmapped_aval(axis_size, d, v.aval)
              for v, d in zip(jaxpr.in_binders, bdims_in)]
  new_jaxpr, new_consts, _ = make_jaxpr(vmap_traceable, *in_avals)
  return new_jaxpr, new_consts
def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
                  ) -> ShapedArray:
  if batch_dim is not_mapped:
    return aval
  else:
    shape = list(aval.shape)
    shape.insert(batch_dim, axis_size)
    return ShapedArray(tuple(shape), aval.dtype) 
def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):
  del num_consts  # Unused
  jaxpr_type = typecheck_jaxpr(jaxpr)
  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
    raise TypeError
  return jaxpr_type.out_types
abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule
def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts):
  del num_consts  # Only used at top-level.
  # Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.
  subc = xc.XlaBuilder('inner xla_call')
  xla_params = _xla_params(subc, in_avals)
  outs = jaxpr_subcomp(subc, jaxpr, xla_params)
  subc = subc.build(xops.Tuple(subc, outs))
  return destructure_tuple(c, xops.Call(c, subc, in_vals))
xla_translations[xla_call_p] = xla_call_translation
def destructure_tuple(c, tup):
  num_elements = len(c.get_shape(tup).tuple_shapes())
  return [xops.GetTupleElement(tup, i) for i in range(num_elements)] 
@jit
def f(x):
  print('tracing!')
  y = sin(x) * 2.
  z = - y + x
  return z
x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot) 
tracing!
2.7177599838802657
2.979984993200891 
y, ydot = jvp(f, (x,), (xdot,))  # 'tracing!' not printed 
ys = vmap(f, (0,))(np.arange(3.))
print(ys) 
[ 0\.         -0.68294197  0.18140515] 

一个遗漏的部分是数组的设备内存持久性。也就是说,我们已经定义了handle_result以将结果作为 NumPy 数组传输回 CPU 内存,但通常最好避免仅为了下一步操作而传输结果。我们可以通过引入Array类来实现这一点,它可以包装 XLA 缓冲区,同时鸭子类型numpy.ndarray

def handle_result(aval: ShapedArray, buf):  # noqa: F811
  return Array(aval, buf)
class Array:
  buf: Any
  aval: ShapedArray
  def __init__(self, aval, buf):
    self.aval = aval
    self.buf = buf
  dtype = property(lambda self: self.aval.dtype)
  shape = property(lambda self: self.aval.shape)
  ndim  = property(lambda self: self.aval.ndim)
  def __array__(self): return np.asarray(self.buf)
  def __repr__(self):  return repr(np.asarray(self.buf))
  def __str__(self):   return str(np.asarray(self.buf))
  _neg = staticmethod(neg)
  _add = staticmethod(add)
  _radd = staticmethod(add)
  _mul = staticmethod(mul)
  _rmul = staticmethod(mul)
  _gt = staticmethod(greater)
  _lt = staticmethod(less)
input_handlers[Array] = lambda x: x.buf
jax_types.add(Array) 
@jit
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z
x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot) 
2.7177599838802657
2.979984993200891 

显示代码单元格源码 隐藏代码单元格源码

def pprint_xla_call(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
  lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
  params_without_jaxpr = {k:v for k, v in eqn.params.items() if k != 'jaxpr'}
  rhs = (pp(eqn.primitive.name) >> pp_params(params_without_jaxpr) >>
         pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
                     for x in eqn.inputs)))
  return vcat([lhs >> pp(' = ') >> rhs,
               pp_jaxpr(eqn.params['jaxpr']).indent(2)])
pp_rules[xla_call_p] = pprint_xla_call 
```</details>
## 第四部分:`linearize`和`vjp`(以及`grad`!)
`linearize`和`vjp`的自动微分函数都建立在`jvp`之上,但也涉及 jaxprs。这是因为两者都涉及分阶段或延迟计算。
### `linearize`
对于`linearize`的情况,我们想要分离出`jvp`计算的线性部分。也就是说,用[Haskell 类型签名](https://wiki.haskell.org/Type_signature)来说,如果我们有`jvp : (a -> b) -> (a, T a) -> (b, T b)`,那么我们会写成`linearize : (a -> b) -> a -> (b, T a -o T b)`,使用`T a`表示“`a`的切线类型”,并使用“棒棒糖”`-o`而不是箭头`->`来指示一个*线性*函数。我们也是以`jvp`的语义来定义`linearize`:
```py
y, f_lin = linearize(f, x)
y_dot = f_lin(x_dot) 

对于(y, y_dot),与原先相同的结果如下:

y, y_dot = jvp(f, (x,), (x_dot,)) 

在应用f_lin时,不会重新执行任何线性化工作。我们将延迟的线性部分f_lin : T a -o T b表示为一个 jaxpr。

顺便说一句,既然我们有了线性箭头-o,我们可以为jvp提供一个稍微更详细的类型:

jvp : (a -> b) -> (UnrestrictedUse a, T a) -o (UnrestrictedUse b, T b) 

我们在这里编写UnrestrictedUse只是为了表明我们有一个特殊的对,第一个元素可以以非线性的方式使用。与线性箭头结合使用时,此符号只是用来表示函数jvp f以非线性方式使用其第一个输入,但以线性方式使用其第二个输入,生成相应的非线性输出(可以以非线性方式使用),与线性输出配对。这种更精细的类型签名编码了jvp f中的数据依赖关系,对于部分评估非常有用。

要从 JVP 构建f_lin的 jaxpr,我们需要执行部分评估:在我们追踪时评估所有原始值,但是将切线计算分阶段到一个 jaxpr 中。这是我们构建 jaxprs 的第二种方式。但是,与make_jaxpr及其基础的JaxprTrace/JaxprTracer解释器的目标是分阶段所有原始绑定不同,这第二种方法仅分阶段那些具有对切线输入的数据依赖性的原始绑定。

首先,一些实用工具:

def split_half(lst: list[Any]) -> tuple[list[Any], list[Any]]:
  assert not len(lst) % 2
  return split_list(lst, len(lst) // 2)
def merge_lists(which: list[bool], l1: list[Any], l2: list[Any]) -> list[Any]:
  l1, l2 = iter(l1), iter(l2)
  out = [next(l2) if b else next(l1) for b in which]
  assert next(l1, None) is next(l2, None) is None
  return out 

接下来,我们将编写linearize,通过将jvp与一般的部分评估转换组合在一起:

def linearize_flat(f, *primals_in):
  pvals_in = ([PartialVal.known(x) for x in primals_in] +
              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
  def f_jvp(*primals_tangents_in):
    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
    return [*primals_out, *tangents_out]
  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)
  primal_pvals, _ = split_half(pvals_out)
  assert all(pval.is_known for pval in primal_pvals)
  primals_out = [pval.const for pval in primal_pvals]
  f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])
  return primals_out, f_lin
def linearize(f, *primals_in):
  primals_in_flat, in_tree = tree_flatten(primals_in)
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)
  def f_lin(*tangents_in):
    tangents_in_flat, in_tree2 = tree_flatten(tangents_in)
    if in_tree != in_tree2: raise TypeError
    tangents_out_flat = f_lin_flat(*tangents_in_flat)
    return tree_unflatten(out_tree(), tangents_out_flat)
  return primals_out, f_lin
def vspace(aval: ShapedArray) -> ShapedArray:
  return raise_to_shaped(aval)  # TODO handle integers? 

现在我们转向一般的部分评估转换。目标是接受一个 Python 可调用函数和一个输入列表,其中一些已知,一些未知,并产生(1)可以从已知输入计算出来的所有输出,以及(2)表示仅在其余输入已知后才能执行的 Python 可调用函数计算的 japxr。

这种转换很难用一个类型签名来总结。如果我们假设输入函数的类型签名是(a1, a2) -> (b1, b2),其中a1a2分别表示已知和未知的输入,并且其中b1仅对a1有数据依赖性,而b2a2有一些数据依赖性,那么我们可能会写成:

partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2) 

简言之,给定类型为 a1 的输入值,partial_eval 将生成类型为 b1 的输出值以及代表在第二阶段完成计算所需的存在量化类型 r 的“残余”值。它还会生成一个类型为 (r, a2) -> b2 的函数,该函数接受残余值以及剩余输入,并生成剩余输出。

我们喜欢将部分评估视为将一个计算“解压”为两个的过程。例如,考虑以下 jaxpr

{ lambda a:float64[] .
  let b:float64[] = sin a
      c:float64[] = neg b
  in ( c ) } 

JVP 的 jaxpr 将如下所示:

{ lambda a:float64[] b:float64[] .
  let c:float64[] = sin a
      d:float64[] = cos a
      e:float64[] = mul d b
      f:float64[] = neg c
      g:float64[] = neg e
  in ( f, g ) } 

如果我们想象将部分评估应用于此 jaxpr,第一个输入已知,第二个输入未知,我们将 JVPjaxpr “解压”为原始和切线 jaxpr

{ lambda a:float64[] .
  let c:float64[] = sin a
      d:float64[] = cos a
      f:float64[] = neg c
  in ( f, d ) } 
{ lambda d:float64[] b:float64[] .
  let e:float64[] = mul d b
      g:float64[] = neg e
  in ( g ) } 

这第二个 jaxpr 表示我们从 linearize 中希望得到的线性计算。

然而,与此 jaxpr 示例不同的是,我们希望在评估输入的 Python 可调用函数时,对已知值进行计算。换句话说,我们不想在整个函数 (a1, a2) -> (b1, b2)jaxpr 中首先将所有操作分离出 Python,然后再确定哪些可以立即评估,哪些必须延迟。我们只想形成那些由于依赖于未知输入而必须延迟的操作的 jaxpr。在自动微分的背景下,这正是使我们能够处理诸如 grad(lambda x: x**2 if x > 0 else 0.) 函数的特性。Python 控制流能够正常工作,因为部分评估保持了 Python 中的原始计算。因此,我们的 TraceTracer 子类必须动态地分辨出哪些可以评估,哪些必须分离到 jaxpr 中。

首先,我们从 PartialVal 类开始,它表示可以是已知或未知的值:

class PartialVal(NamedTuple):
  aval: ShapedArray
  const: Optional[Any]
  @classmethod
  def known(cls, val: Any):
    return PartialVal(get_aval(val), val)
  @classmethod
  def unknown(cls, aval: ShapedArray):
    return PartialVal(aval, None)
  is_known   = property(lambda self: self.const is not None)
  is_unknown = property(lambda self: self.const is     None) 

部分评估将接受一个表示输入的 PartialVal 列表,并返回一个表示延迟计算的 jaxprPartialVal 输出列表:

def partial_eval_flat(f: Callable, pvals_in: list[PartialVal]
                      ) -> tuple[Jaxpr, list[PartialVal], list[Any]]:
  with new_main(PartialEvalTrace) as main:
    trace = PartialEvalTrace(main)
    tracers_in = [trace.new_arg(pval) for pval in pvals_in]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    pvals_out = [t.pval for t in tracers_out]
    unk_tracers_in  = [t for t in tracers_in  if t.pval.is_unknown]
    unk_tracers_out = [t for t in tracers_out if t.pval.is_unknown]
    jaxpr, consts = tracers_to_jaxpr(unk_tracers_in, unk_tracers_out)
  return jaxpr, pvals_out, consts 

接下来,我们需要实现 PartialEvalTrace 及其 PartialEvalTracer。此解释器将在跟踪数据依赖关系的同时动态构建 jaxpr。为此,它在 PartialEvalTracer 节点(代表分阶段的值)和 JaxprRecipe 节点(代表如何从其他值计算某些值的公式)之间建立了一个二部有向无环图(DAG)。一种类型的配方是 JaxprEqnRecipe,对应于 JaxprEqn 的原语应用,但我们还有常量和 Lambda 绑定器的配方类型:

from weakref import ref, ReferenceType
class LambdaBindingRecipe(NamedTuple):
  pass
class ConstRecipe(NamedTuple):
  val: Any
class JaxprEqnRecipe(NamedTuple):
  prim: Primitive
  tracers_in: list['PartialEvalTracer']
  params: dict[str, Any]
  avals_out: list[ShapedArray]
  tracer_refs_out: list['ReferenceType[PartialEvalTracer]']
JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe] 
class PartialEvalTracer(Tracer):
  pval: PartialVal
  recipe: Optional[JaxprRecipe]
  def __init__(self, trace, pval, recipe):
    self._trace = trace
    self.pval = pval
    self.recipe = recipe
  aval = property(lambda self: self.pval.aval)
  def full_lower(self):
    if self.pval.is_known:
      return full_lower(self.pval.const)
    return self 

PartialEvalTrace 包含构建 JaxprRecipePartialEvalTracer 图形的逻辑。每个参数对应于 LambdaBindingRecipe 叶节点,每个常量都是一个 ConstRecipe 叶节点,保存对常量的引用。所有其他跟踪器和配方都来自 process_primitive,它使用 JaxprEqnRecipe 形成具有 JaxprEqn 的原语应用的跟踪器。

对于大多数原语,process_primitive逻辑很简单:如果所有输入都已知,我们可以在已知值上绑定原语(在 Python 中评估它),并避免形成对应于输出的追踪器。如果任何输入未知,则我们转而进行JaxprEqnRecipe的阶段输出,表示原语应用。为了构建代表未知输出的追踪器,我们需要 aval,这些 aval 来自抽象评估规则。(请注意,追踪器引用JaxprEqnRecipe,而JaxprEqnRecipe引用追踪器;我们通过使用弱引用来避免循环垃圾。)

process_primitive逻辑适用于大多数原语,但xla_call_p需要递归处理。因此,我们在partial_eval_rules字典中特别处理它的规则。

class PartialEvalTrace(Trace):
  def new_arg(self, pval: PartialVal) -> Any:
    return PartialEvalTracer(self, pval, LambdaBindingRecipe())
  def lift(self, val: Any) -> PartialEvalTracer:
    return PartialEvalTracer(self, PartialVal.known(val), None)
  pure = lift
  def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:
    if tracer.pval.is_unknown:
      return tracer
    else:
      pval = PartialVal.unknown(raise_to_shaped(tracer.aval))
      return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))
  def process_primitive(self, primitive, tracers, params):
    if all(t.pval.is_known for t in tracers):
      return bind(primitive, *map(full_lower, tracers), **params)
    rule = partial_eval_rules.get(primitive)
    if rule: return rule(self, tracers, **params)
    tracers_in = [self.instantiate_const(t) for t in tracers]
    avals_in = [t.aval for t in tracers_in]
    avals_out = abstract_eval_rulesprimitive
    tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)
                   for aval in avals_out]
    eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,
                         map(ref, tracers_out))
    for t in tracers_out: t.recipe = eqn
    return tracers_out
partial_eval_rules = {} 

现在我们可以用PartialEvalTrace构建 jaxprs 的图形表示,我们需要一种机制将图形表示转换为标准的 jaxpr。jaxpr 对应于图形的拓扑排序。

def tracers_to_jaxpr(tracers_in: list[PartialEvalTracer],
                     tracers_out: list[PartialEvalTracer]):
  tracer_to_var: dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))
                                   for t in tracers_in}
  constvar_to_val: dict[int, Any] = {}
  constid_to_var: dict[int, Var] = {}
  processed_eqns: set[int] = set()
  eqns: list[JaxprEqn] = []
  for t in toposort(tracers_out, tracer_parents):
    if isinstance(t.recipe, LambdaBindingRecipe):
      assert id(t) in set(map(id, tracers_in))
    elif isinstance(t.recipe, ConstRecipe):
      val = t.recipe.val
      var = constid_to_var.get(id(val))
      if var is None:
        aval = raise_to_shaped(get_aval(val))
        var = constid_to_var[id(val)] = Var(aval)
        constvar_to_val[var] = val
      tracer_to_var[id(t)] = var
    elif isinstance(t.recipe, JaxprEqnRecipe):
      if id(t.recipe) not in processed_eqns:
        eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))
        processed_eqns.add(id(t.recipe))
    else:
      raise TypeError(t.recipe)
  constvars, constvals = unzip2(constvar_to_val.items())
  in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
  out_vars = [tracer_to_var[id(t)] for t in tracers_out]
  jaxpr = Jaxpr(in_binders, eqns, out_vars)
  typecheck_jaxpr(jaxpr)
  return jaxpr, constvals
def recipe_to_eqn(tracer_to_var: dict[int, Var], recipe: JaxprEqnRecipe
                  ) -> JaxprEqn:
  inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]
  out_binders = [Var(aval) for aval in recipe.avals_out]
  for t_ref, var in zip(recipe.tracer_refs_out, out_binders):
    if t_ref() is not None: tracer_to_var[id(t_ref())] = var
  return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)
def tracer_parents(t: PartialEvalTracer) -> list[PartialEvalTracer]:
  return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else [] 

显示代码单元源代码 隐藏代码单元源代码

def toposort(out_nodes: list[Any], parents: Callable[[Any], list[Any]]):
  if not out_nodes: return []
  out_nodes = remove_duplicates(out_nodes)
  child_counts = {}
  stack = list(out_nodes)
  while stack:
    node = stack.pop()
    if id(node) in child_counts:
      child_counts[id(node)] += 1
    else:
      child_counts[id(node)] = 1
      stack.extend(parents(node))
  for node in out_nodes:
    child_counts[id(node)] -= 1
  sorted_nodes = []
  childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]
  while childless_nodes:
    node = childless_nodes.pop()
    sorted_nodes.append(node)
    for parent in parents(node):
      if child_counts[id(parent)] == 1:
        childless_nodes.append(parent)
      else:
        child_counts[id(parent)] -= 1
  sorted_nodes = sorted_nodes[::-1]
  check_toposort(sorted_nodes, parents)
  return sorted_nodes
def remove_duplicates(lst):
  seen = set()
  return [x for x in lst if id(x) not in seen and not seen.add(id(x))]
def check_toposort(nodes: list[Any], parents: Callable[[Any], list[Any]]):
  seen = set()
  for node in nodes:
    assert all(id(parent) in seen for parent in parents(node))
    seen.add(id(node)) 
```</details>
现在我们可以进行线性化了!
```py
y, sin_lin = linearize(sin, 3.)
print(y, sin(3.))
print(sin_lin(1.), cos(3.)) 
0.1411200080598672 0.1411200080598672
-0.9899924966004454 -0.9899924966004454 

要处理linearize-of-jit,我们仍然需要为xla_call_p编写部分评估规则。除了追踪器的记账外,主要任务是对 jaxpr 执行部分评估,将其“解压”为两个 jaxpr。

实际上有两个规则需要编写:一个是跟踪时间部分评估的规则,我们将其称为xla_call_partial_eval,另一个是 jaxprs 的部分评估规则,我们将其称为xla_call_peval_eqn

def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
  del num_consts  # Unused
  in_unknowns = [not t.pval.is_known for t in tracers]
  jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)
  known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
  known_vals = [t.pval.const for t in known_tracers]
  outs1_res = bind(xla_call_p, *known_vals, jaxpr=jaxpr1, num_consts=0)
  outs1, res = split_list(outs1_res, len(jaxpr1.outs) - num_res)
  res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
           for v in jaxpr2.outs]
  eqn = JaxprEqnRecipe(xla_call_p, res_tracers + unknown_tracers,
                       dict(jaxpr=jaxpr2, num_consts=0),
                       [v.aval for v in jaxpr2.outs], map(ref, outs2))
  for t in outs2: t.recipe = eqn
  return merge_lists(out_unknowns, outs1, outs2)
partial_eval_rules[xla_call_p] = xla_call_partial_eval
def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
                       instantiate: Optional[list[bool]] = None,
                       ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
  env: dict[Var, bool] = {}
  residuals: set[Var] = set()
  def read(x: Atom) -> bool:
    return type(x) is Var and env[x]
  def write(unk: bool, v: Var) -> None:
    env[v] = unk
  def new_res(x: Atom) -> Atom:
    if type(x) is Var: residuals.add(x)
    return x
  eqns1, eqns2 = [], []
  map(write, in_unknowns, jaxpr.in_binders)
  for eqn in jaxpr.eqns:
    unks_in = map(read, eqn.inputs)
    rule = partial_eval_jaxpr_rules.get(eqn.primitive)
    if rule:
      eqn1, eqn2, unks_out, res = rule(unks_in, eqn)
      eqns1.append(eqn1); eqns2.append(eqn2); residuals.update(res)
      map(write, unks_out, eqn.out_binders)
    elif any(unks_in):
      inputs = [v if unk else new_res(v) for unk, v in zip(unks_in, eqn.inputs)]
      eqns2.append(JaxprEqn(eqn.primitive, inputs, eqn.params, eqn.out_binders))
      map(partial(write, True), eqn.out_binders)
    else:
      eqns1.append(eqn)
      map(partial(write, False), eqn.out_binders)
  out_unknowns = map(read, jaxpr.outs)
  if instantiate is not None:
    for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):
      if inst and not uk: new_res(v)
    out_unknowns = map(op.or_, out_unknowns, instantiate)
  residuals, num_res = list(residuals), len(residuals)
  assert all(type(v) is Var for v in residuals), residuals
  ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
  outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)
  jaxpr1 = Jaxpr(ins1, eqns1, outs1 + residuals)
  jaxpr2 = Jaxpr(residuals + ins2, eqns2, outs2)
  typecheck_partial_eval_jaxpr(jaxpr, in_unknowns, out_unknowns, jaxpr1, jaxpr2)
  return jaxpr1, jaxpr2, out_unknowns, num_res
def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
  jaxprty = typecheck_jaxpr(jaxpr)    # (a1,  a2) -> (b1, b2 )
  jaxpr1ty = typecheck_jaxpr(jaxpr1)  #  a1       -> (b1, res)
  jaxpr2ty = typecheck_jaxpr(jaxpr2)  # (res, a2) -> b2
  a1, a2 = partition_list(unks_in, jaxprty.in_types)
  b1, b2 = partition_list(unks_out, jaxprty.out_types)
  b1_, res = split_list(jaxpr1ty.out_types, len(b1))
  res_, a2_ = split_list(jaxpr2ty.in_types, len(res))
  b2_ = jaxpr2ty.out_types
  if jaxpr1ty.in_types != a1: raise TypeError
  if jaxpr2ty.out_types != b2: raise TypeError
  if b1 != b1_: raise TypeError
  if res != res_: raise TypeError
  if a2 != a2_: raise TypeError
  if b2 != b2_: raise TypeError
partial_eval_jaxpr_rules = {}
def xla_call_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
                       ) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Var]]:
  jaxpr = eqn.params['jaxpr']
  jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
  ins1, ins2 = partition_list(unks_in, eqn.inputs)
  out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)
  residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]
  eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),
                  out_binders1 + residuals)
  eqn2 = JaxprEqn(xla_call_p, residuals + ins2,
                  dict(jaxpr=jaxpr2, num_consts=0), out_binders2)
  return eqn1, eqn2, unks_out, residuals
partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn 

通过这样,我们可以随心所欲地组合linearizejit

@jit
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z
y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot) 
2.7177599838802657 2.979984993200891 
@jit
def f(x):
  y = sin(x) * 2.
  z = g(x, y)
  return z
@jit
def g(x, y):
  return cos(x) + y
y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot) 
-0.7077524804807109 -2.121105001260758 

vjpgrad

vjp变换的工作方式与线性化非常相似。其类型签名类似:

linearize : (a -> b) -> a -> (b, T a -o T b)
vjp       : (a -> b) -> a -> (b, T b -o T a) 

唯一的区别在于,我们在返回之前转置计算的线性部分,以便从类型T a -o T b变为类型T b -o T a。也就是说,我们将vjp实现为以下内容:

def vjp(f, x):
  y, f_lin = linearize(f, x)
  f_vjp = lambda y_bar: transpose(f_lin)(y_bar)
  return y, f_vjp 

由于我们将线性计算作为 jaxpr,而不仅仅是 Python 可调用的函数,因此我们可以将转置转换实现为 jaxpr 解释器。

def vjp_flat(f, *primals_in):
  pvals_in = ([PartialVal.known(x) for x in primals_in] +
              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
  primal_pvals_in, tangent_pvals_in = split_half(pvals_in)
  def f_jvp(*primals_tangents_in):
    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
    return [*primals_out, *tangents_out]
  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)  # linearize
  primal_pvals, _ = split_half(pvals_out)
  assert all(pval.is_known for pval in primal_pvals)
  primals_out = [pval.const for pval in primal_pvals]
  transpose_inputs = consts + [UndefPrimal(p.aval) for p in tangent_pvals_in]
  f_vjp = lambda *cts: eval_jaxpr_transposed(jaxpr, transpose_inputs, cts)
  return primals_out, f_vjp
def vjp(f, *primals_in):
  primals_in_flat, in_tree = tree_flatten(primals_in)
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, f_vjp_flat = vjp_flat(f, *primals_in_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)
  def f_vjp(*cotangents_out):
    cotangents_out_flat, _ = tree_flatten(cotangents_out)
    cotangents_in_flat = f_vjp_flat(*cotangents_out_flat)
    return tree_unflatten(in_tree, cotangents_in_flat)
  return primals_out, f_vjp
class UndefPrimal(NamedTuple):
  aval: ShapedArray
register_pytree_node(UndefPrimal,
                     lambda u: (u.aval, ()),
                     lambda aval, _: UndefPrimal(aval)) 

我们使用UndefPrimal实例来指示我们希望进行转置的参数。这是因为通常情况下,我们想要明确关闭值,我们希望将类型为a -> b -o c的函数转置为类型为a -> c -o b的函数。更一般地说,与函数线性相关的输入可能分散在参数列表中。因此,我们使用UndefPrimal指示线性位置。我们将UndefPrimal注册为一个 pytree 节点,因为 pytree 机制提供了一种方便的方法来从参数列表中剪除这些占位符。

接下来,我们可以编写eval_jaxpr_transposed,以及对所有至少可以线性的原语编写转置规则:

# NB: the analogous function in JAX is called 'backward_pass'
def eval_jaxpr_transposed(jaxpr: Jaxpr, args: list[Any], cotangents: list[Any]
                          ) -> list[Any]:
  primal_env: dict[Var, Any] = {}
  ct_env: dict[Var, Any] = {}
  def read_primal(x: Atom) -> Any:
    return primal_env.get(x, UndefPrimal(x.aval)) if type(x) is Var else x.val
  def write_primal(v: Var, val: Any) -> None:
    if type(val) is not UndefPrimal:
      primal_env[v] = val
  def read_cotangent(v: Var) -> Any:
    return ct_env.pop(v, np.zeros(v.aval.shape, v.aval.dtype))
  def write_cotangent(x: Atom, val: Any):
    if type(x) is Var and val is not None:
      ct_env[x] = add(ct_env[x], val) if x in ct_env else val
  map(write_primal, jaxpr.in_binders, args)
  map(write_cotangent, jaxpr.outs, cotangents)
  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.inputs)
    cts_in = map(read_cotangent, eqn.out_binders)
    rule = transpose_rules[eqn.primitive]
    cts_out = rule(cts_in, *primals_in, **eqn.params)
    map(write_cotangent, eqn.inputs, cts_out)
  return [read_cotangent(v) for v, x in zip(jaxpr.in_binders, args)
          if type(x) is UndefPrimal]
transpose_rules = {} 
def mul_transpose_rule(cts, x, y):
  z_bar, = cts
  assert (type(x) is UndefPrimal) ^ (type(y) is UndefPrimal)
  return [mul(z_bar, y), None] if type(x) is UndefPrimal else [None, mul(x, z_bar)]
transpose_rules[mul_p] = mul_transpose_rule
def neg_transpose_rule(cts, x):
  ybar, = cts
  assert type(x) is UndefPrimal
  return [neg(ybar)]
transpose_rules[neg_p] = neg_transpose_rule
def add_transpose_rule(cts, x, y):
  z_bar, = cts
  return [z_bar, z_bar]
transpose_rules[add_p] = add_transpose_rule
def reduce_sum_transpose_rule(cts, x, *, axis):
  y_bar, = cts
  return [broadcast(y_bar, x.aval.shape, axis)]
transpose_rules[reduce_sum_p] = reduce_sum_transpose_rule
def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
  del num_consts  # Unused
  undef_primals = [type(x) is UndefPrimal for x in invals]
  transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))
  residuals, _ = partition_list(undef_primals, invals)
  outs = bind(xla_call_p, *new_consts, *residuals, *cts,
              jaxpr=transposed_jaxpr, num_consts=len(new_consts))
  outs = iter(outs)
  return [next(outs) if undef else None for undef in undef_primals]
transpose_rules[xla_call_p] = xla_call_transpose_rule
@lru_cache()
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
                    ) -> tuple[Jaxpr, list[Any]]:
  avals_in, avals_out = typecheck_jaxpr(jaxpr)
  traceable = partial(eval_jaxpr_transposed, jaxpr)
  args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)]
  trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out))
  typecheck_jaxpr(trans_jaxpr)
  return trans_jaxpr, consts 

现在我们可以进行线性化和转置,最后我们可以编写grad

def grad(f):
  def gradfun(x, *xs):
    y, f_vjp = vjp(f, x, *xs)
    if np.shape(y) != (): raise TypeError
    x_bar, *_ = f_vjp(np.ones(np.shape(y), np.result_type(y)))
    return x_bar
  return gradfun 
y, f_vjp = vjp(sin, 3.)
print(f_vjp(1.), cos(3.))
(np.float64(-0.9899924966004454),) -0.9899924966004454 
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z
print(grad(f)(3.)) 
2.979984993200891 
@jit
def f(x):
  y = x * 2.
  z = g(y)
  return z
@jit
def g(x):
  return cos(x) * 2.
print(grad(f)(3.)) 
1.1176619927957034 

这里是一个组合性压力测试:

# from core_test.py fun_with_nested_calls_2
def foo(x):
  @jit
  def bar(y):
    def baz(w):
      q = jit(lambda x: y)(x)
      q = q + jit(lambda: y)()
      q = q + jit(lambda y: w + y)(y)
      q = jit(lambda w: jit(sin)(x) * y)(1.0) + q
      return q
    p, t = jvp(baz, (x + 1.0,), (y,))
    return t + (x * p)
  return bar(x)
def assert_allclose(*vals):
  for v1, v2 in zip(vals[:-1], vals[1:]):
    np.testing.assert_allclose(v1, v2)
ans1 = f(3.)
ans2 = jit(f)(3.)
ans3, _ = jvp(f, (3.,), (5.,))
ans4, _ = jvp(jit(f), (3.,), (5.,))
assert_allclose(ans1, ans2, ans3, ans4)
deriv1 = grad(f)(3.)
deriv2 = grad(jit(f))(3.)
deriv3 = jit(grad(jit(f)))(3.)
_, deriv4 = jvp(f, (3.,), (1.,))
_, deriv5 = jvp(jit(f), (3.,), (1.,))
assert_allclose(deriv1, deriv2, deriv3, deriv4, deriv5)
hess1 = grad(grad(f))(3.)
hess2 = grad(grad(jit(f)))(3.)
hess3 = grad(jit(grad(f)))(3.)
hess4 = jit(grad(grad(f)))(3.)
_, hess5 = jvp(grad(f), (3.,), (1.,))
_, hess6 = jvp(jit(grad(f)), (3.,), (1.,))
_, hess7 = jvp(jit(grad(f)), (3.,), (1.,))
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7) 


JAX 中文文档(十)(5)https://developer.aliyun.com/article/1559711

相关文章
|
3月前
|
机器学习/深度学习 PyTorch API
JAX 中文文档(六)(1)
JAX 中文文档(六)
26 0
JAX 中文文档(六)(1)
|
3月前
|
存储 安全 API
JAX 中文文档(十)(2)
JAX 中文文档(十)
33 0
|
3月前
|
机器学习/深度学习
JAX 中文文档(六)(5)
JAX 中文文档(六)
24 0
|
3月前
|
存储 机器学习/深度学习 编译器
JAX 中文文档(九)(1)
JAX 中文文档(九)
34 0
|
3月前
|
机器学习/深度学习 缓存 API
JAX 中文文档(一)(4)
JAX 中文文档(一)
40 0
|
3月前
|
API Python
JAX 中文文档(八)(3)
JAX 中文文档(八)
26 0
|
3月前
|
C++ 索引 Python
JAX 中文文档(九)(2)
JAX 中文文档(九)
19 0
|
3月前
|
Python
JAX 中文文档(十)(5)
JAX 中文文档(十)
23 0
|
3月前
|
机器学习/深度学习 程序员 编译器
JAX 中文文档(三)(1)
JAX 中文文档(三)
28 0
|
3月前
|
编译器 异构计算 索引
JAX 中文文档(五)(4)
JAX 中文文档(五)
53 0