JAX 中文文档(十)(2)

简介: JAX 中文文档(十)

JAX 中文文档(十)(1)https://developer.aliyun.com/article/1559707


Autodidax:从头开始学习 JAX 核心

原文:jax.readthedocs.io/en/latest/autodidax.html

你是否想过学习 JAX 是如何工作的,但实现看起来深奥无比?那么,你很幸运!通过阅读本教程,你将了解 JAX 核心系统中的每一个重要思想。你甚至将了解我们奇怪的行话!

这是一个正在进行中的草稿。 这里还缺少一些重要的部分,将在第五部分和第六部分(以及更多?)中添加。此外,这里还有一些尚未应用于主系统的简化,但我们会应用的。

第一部分:转换作为解释器:标准评估、jvpvmap

我们希望转换看起来像这样的函数:

def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z 

将函数如sin和作为中缀操作符底层的算术运算(muladdneg)视为原语操作,意味着它们是处理的原子单位而不是组合。

“Transform”意味着“以不同方式解释”。我们不再采用标准解释,其中我们将原语操作应用于数值输入以生成数值输出,而是想要重写原语应用,并让不同的值流过我们的程序。例如,我们可能希望用其 JVP 规则的应用替换每个原语的应用,并让原始-切线对流经我们的程序。此外,我们希望能够组合多个转换,形成解释器的堆栈。

JAX 核心机制

我们可以实现解释器的堆栈,甚至可以在执行要转换的 Python 函数时实时执行它们。首先,让我们定义这些原语,以便我们可以拦截它们的应用:

from typing import NamedTuple
class Primitive(NamedTuple):
  name: str
add_p = Primitive('add')
mul_p = Primitive('mul')
neg_p = Primitive("neg")
sin_p = Primitive("sin")
cos_p = Primitive("cos")
reduce_sum_p = Primitive("reduce_sum")
greater_p = Primitive("greater")
less_p = Primitive("less")
transpose_p = Primitive("transpose")
broadcast_p = Primitive("broadcast")
def add(x, y): return bind1(add_p, x, y)
def mul(x, y): return bind1(mul_p, x, y)
def neg(x): return bind1(neg_p, x)
def sin(x): return bind1(sin_p, x)
def cos(x): return bind1(cos_p, x)
def greater(x, y): return bind1(greater_p, x, y)
def less(x, y): return bind1(less_p, x, y)
def transpose(x, perm): return bind1(transpose_p, x, perm=perm)
def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)
def reduce_sum(x, axis=None):
  if axis is None:
    axis = tuple(range(np.ndim(x)))
  if type(axis) is int:
    axis = (axis,)
  return bind1(reduce_sum_p, x, axis=axis)
def bind1(prim, *args, **params):
  out, = bind(prim, *args, **params)
  return out 

我们稍后将设置数组数据类型和中缀操作方法。

一个Primitive只是一个带有名称的对象,我们附加了我们的解释规则(每个转换对应一个规则)。bind函数是我们的拦截点:它将根据参数在跟踪器中的封装方式以及活动的解释器来确定应用哪个转换规则。

用户代码调用的函数,如addsin,只是对bind调用的包装器。这些包装器允许我们控制参数如何传递给bind,特别是我们遵循一个方便的内部约定:当我们调用bind时,我们将表示数组数据的值作为位置参数传递,并通过关键字将元数据(如axis参数传递给sum_p)。这种调用约定简化了一些核心逻辑(因为例如下文将要定义的Tracer类的实例只能出现在bind的位置参数中)。这些包装器还可以提供文档字符串!

我们将活动解释器表示为堆栈。堆栈只是一个简单的list,每个元素是一个容器,具有整数级别(对应于元素在堆栈中的高度)、解释器类型(我们称之为trace_type)以及解释器需要的任何全局数据的可选字段。我们称每个元素为MainTrace,尽管“Interpreter”可能更加描述性。

from collections.abc import Sequence
from contextlib import contextmanager
from typing import Optional, Any
class MainTrace(NamedTuple):
  level: int
  trace_type: type['Trace']
  global_data: Optional[Any]
trace_stack: list[MainTrace] = []
dynamic_trace: Optional[MainTrace] = None  # to be employed in Part 3
@contextmanager
def new_main(trace_type: type['Trace'], global_data=None):
  level = len(trace_stack)
  main = MainTrace(level, trace_type, global_data)
  trace_stack.append(main)
  try:
    yield main
  finally:
    trace_stack.pop() 

在我们准备应用变换时,我们将使用new_main将另一个解释器推送到堆栈上。然后,在函数中应用原语时,我们可以认为bind首先由堆栈顶部的追踪器解释(即具有最高级别的追踪器)。如果第一个解释器本身在其对于原语的解释规则中绑定其他原语,例如sin_p的 JVP 规则可能绑定cos_pmul_p,那么这些bind调用将由下一个级别的解释器处理。

解释器堆栈的底部放什么?在底部,我们知道所有变换解释器都已完成,我们只想进行标准评估。因此,在底部我们将放置一个评估解释器。

让我们概述一下解释器的接口,它基于TraceTracer基类。Tracer表示一个封装的值,可能携带一些由解释器使用的额外上下文数据。Trace处理将值封装到Tracer中,并且还处理原语应用。

class Trace:
  main: MainTrace
  def __init__(self, main: MainTrace) -> None:
    self.main = main
  def pure(self, val): assert False  # must override
  def lift(self, val): assert False  # must override
  def process_primitive(self, primitive, tracers, params):
    assert False  # must override 

前两种方法是关于在Tracer中封装值,Tracer是我们转换的 Python 程序中流动的对象。最后一种方法是我们将用于解释原始应用的回调。

Trace本身除了引用其对应的MainTrace实例之外并不包含任何数据。事实上,在应用变换过程中可能会创建和丢弃多个Trace实例,而每个应用变换只会创建一个MainTrace实例。

至于Tracer们本身,每个Tracer都携带一个抽象值(并将中缀运算符转发给它),其余由变换决定。(TracerAbstractValue之间的关系是每个变换对应一个Tracer,并且每个基本类型(如数组)至少有一个AbstractValue。)

import numpy as np
class Tracer:
  _trace: Trace
  __array_priority__ = 1000
  @property
  def aval(self):
    assert False  # must override
  def full_lower(self):
    return self  # default implementation
  def __neg__(self): return self.aval._neg(self)
  def __add__(self, other): return self.aval._add(self, other)
  def __radd__(self, other): return self.aval._radd(self, other)
  def __mul__(self, other): return self.aval._mul(self, other)
  def __rmul__(self, other): return self.aval._rmul(self, other)
  def __gt__(self, other): return self.aval._gt(self, other)
  def __lt__(self, other): return self.aval._lt(self, other)
  def __bool__(self): return self.aval._bool(self)
  def __nonzero__(self): return self.aval._nonzero(self)
  def __getattr__(self, name):
    try:
      return getattr(self.aval, name)
    except AttributeError:
      raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")
def swap(f): return lambda x, y: f(y, x) 
class ShapedArray:
  array_abstraction_level = 1
  shape: tuple[int, ...]
  dtype: np.dtype
  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = dtype
  @property
  def ndim(self):
    return len(self.shape)
  _neg = staticmethod(neg)
  _add = staticmethod(add)
  _radd = staticmethod(swap(add))
  _mul = staticmethod(mul)
  _rmul = staticmethod(swap(mul))
  _gt = staticmethod(greater)
  _lt = staticmethod(less)
  @staticmethod
  def _bool(tracer):
    raise Exception("ShapedArray can't be unambiguously converted to bool")
  @staticmethod
  def _nonzero(tracer):
    raise Exception("ShapedArray can't be unambiguously converted to bool")
  def str_short(self):
    return f'{self.dtype.name}[{",".join(str(d)  for  d  in  self.shape)}]'
  def __hash__(self):
    return hash((self.shape, self.dtype))
  def __eq__(self, other):
    return (type(self) is type(other) and
            self.shape == other.shape and self.dtype == other.dtype)
  def __repr__(self):
    return f"ShapedArray(shape={self.shape}, dtype={self.dtype})"
class ConcreteArray(ShapedArray):
  array_abstraction_level = 2
  val: np.ndarray
  def __init__(self, val):
    self.val = val
    self.shape = val.shape
    self.dtype = val.dtype
  @staticmethod
  def _bool(tracer):
    return bool(tracer.aval.val)
  @staticmethod
  def _nonzero(tracer):
    return bool(tracer.aval.val)
def get_aval(x):
  if isinstance(x, Tracer):
    return x.aval
  elif type(x) in jax_types:
    return ConcreteArray(np.asarray(x))
  else:
    raise TypeError(x)
jax_types = {bool, int, float,
             np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray} 

注意,实际上我们为数组有两个AbstractValue,代表不同的抽象级别。ShapedArray代表具有给定形状和 dtype 的所有可能数组的集合。ConcreteArray代表一个由单个数组值组成的单例集。

现在我们已经设置了解释器堆栈、解释器的 Trace/Tracer API 和抽象值,我们可以回来实现bind了:

def bind(prim, *args, **params):
  top_trace = find_top_trace(args)
  tracers = [full_raise(top_trace, arg) for arg in args]
  outs = top_trace.process_primitive(prim, tracers, params)
  return [full_lower(out) for out in outs] 

主要的操作是我们调用find_top_trace来找出哪个解释器应该处理这个基元应用。然后我们调用该顶层跟踪的process_primitive,以便跟踪可以应用其解释规则。full_raise的调用只是确保输入封装在顶层跟踪的Tracer实例中,而对full_lower的调用是一个可选的优化,以便我们尽可能多地从Tracer中解封值。

import operator as op
def find_top_trace(xs) -> Trace:
  top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
                 default=trace_stack[0], key=op.attrgetter('level'))
  if dynamic_trace and dynamic_trace.level > top_main.level:
    top_main = dynamic_trace
  return top_main.trace_type(top_main) 

换句话说,忽略dynamic_trace步骤直到第三部分,find_top_trace返回与其输入上的Tracer相关联的最高级解释器,并且否则返回堆栈底部的解释器(至少目前总是一个求值跟踪)。这与上面的描述有所偏离,我们总是从运行堆栈顶部的解释器开始,然后逐级向下工作,应用堆栈中的每个解释器。相反,我们只有在将输入参数传递给基元绑定的Tracer中时才应用解释器对应的解释器时才应用解释器。这种优化让我们可以跳过不相关的转换,但内置了一个假设,即转换大部分时候都遵循数据依赖性(除了特殊的堆栈底部解释器,它解释一切)。

另一种方法是使堆栈中的每个解释器都解释每个操作。值得探索!JAX 大部分是围绕数据依赖性而设计的,大部分原因是因为这对于自动微分来说非常自然,而 JAX 的根源在于自动微分。但也许会过拟合。

def full_lower(val: Any):
  if isinstance(val, Tracer):
    return val.full_lower()
  else:
    return val
def full_raise(trace: Trace, val: Any) -> Tracer:
  if not isinstance(val, Tracer):
    assert type(val) in jax_types
    return trace.pure(val)
  level = trace.main.level
  if val._trace.main is trace.main:
    return val
  elif val._trace.main.level < level:
    return trace.lift(val)
  elif val._trace.main.level > level:
    raise Exception(f"Can't lift level {val._trace.main.level} to {level}.")
  else:  # val._trace.level == level
    raise Exception(f"Different traces at same level: {val._trace}, {trace}.") 

full_raise中的逻辑用于将值封装在特定TraceTracer中,根据上下文对Trace调用不同的方法:对非Tracer常数调用Trace.pure,对已经来自低级解释器的Tracer调用Trace.lift。这两种方法可以共享相同的实现,但通过在核心逻辑中加以区分,我们可以向Trace子类提供更多信息。

JAX 核心就是这样!现在我们可以开始添加解释器了。

评估解释器

我们将从最简单的解释器开始:位于解释器堆栈底部的评估解释器。

class EvalTrace(Trace):
  pure = lift = lambda self, x: x  # no boxing in Tracers needed
  def process_primitive(self, primitive, tracers, params):
    return impl_rulesprimitive
trace_stack.append(MainTrace(0, EvalTrace, None))  # special bottom of the stack
# NB: in JAX, instead of a dict we attach impl rules to the Primitive instance
impl_rules = {}
impl_rules[add_p] = lambda x, y: [np.add(x, y)]
impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)]
impl_rules[neg_p] = lambda x: [np.negative(x)]
impl_rules[sin_p] = lambda x: [np.sin(x)]
impl_rules[cos_p] = lambda x: [np.cos(x)]
impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]
impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]
impl_rules[less_p] = lambda x, y: [np.less(x, y)]
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]
def broadcast_impl(x, *, shape, axes):
  for axis in sorted(axes):
    x = np.expand_dims(x, axis)
  return [np.broadcast_to(x, shape)]
impl_rules[broadcast_p] = broadcast_impl 

有了这个解释器,我们可以评估用户函数:

def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z
print(f(3.0)) 
2.7177599838802657 

哇!就像在一个大圈子里转圈。但这种间接性的关键在于现在我们可以添加一些真正的转换。

带有jvp的前向模式自动微分

首先,一些辅助函数:

import builtins
def zeros_like(val):
  aval = get_aval(val)
  return np.zeros(aval.shape, aval.dtype)
def unzip2(pairs):
  lst1, lst2 = [], []
  for x1, x2 in pairs:
    lst1.append(x1)
    lst2.append(x2)
  return lst1, lst2
def map(f, *xs):
  return list(builtins.map(f, *xs))
def zip(*args):
  fst, *rest = args = map(list, args)
  n = len(fst)
  for arg in rest:
    assert len(arg) == n
  return list(builtins.zip(*args)) 

前向模式自动微分的Tracer携带原始-切线对。Trace应用 JVP 规则。

class JVPTracer(Tracer):
  def __init__(self, trace, primal, tangent):
    self._trace = trace
    self.primal = primal
    self.tangent = tangent
  @property
  def aval(self):
    return get_aval(self.primal)
class JVPTrace(Trace):
  pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val))
  def process_primitive(self, primitive, tracers, params):
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    jvp_rule = jvp_rules[primitive]
    primal_outs, tangent_outs = jvp_rule(primals_in, tangents_in, **params)
    return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)]
jvp_rules = {} 

注意purelift都将一个值打包成一个带有最小上下文的JVPTracer,这是一个零切线值。

让我们添加一些用于原始函数的 JVP 规则:

def add_jvp(primals, tangents):
  (x, y), (x_dot, y_dot) = primals, tangents
  return [x + y], [x_dot + y_dot]
jvp_rules[add_p] = add_jvp
def mul_jvp(primals, tangents):
  (x, y), (x_dot, y_dot) = primals, tangents
  return [x * y], [x_dot * y + x * y_dot]
jvp_rules[mul_p] = mul_jvp
def sin_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return [sin(x)], [cos(x) * x_dot]
jvp_rules[sin_p] = sin_jvp
def cos_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return [cos(x)], [-sin(x) * x_dot]
jvp_rules[cos_p] = cos_jvp
def neg_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return [neg(x)], [neg(x_dot)]
jvp_rules[neg_p] = neg_jvp
def reduce_sum_jvp(primals, tangents, *, axis):
  (x,), (x_dot,) = primals, tangents
  return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]
jvp_rules[reduce_sum_p] = reduce_sum_jvp
def greater_jvp(primals, tangents):
  (x, y), _ = primals, tangents
  out_primal = greater(x, y)
  return [out_primal], [zeros_like(out_primal)]
jvp_rules[greater_p] = greater_jvp
def less_jvp(primals, tangents):
  (x, y), _ = primals, tangents
  out_primal = less(x, y)
  return [out_primal], [zeros_like(out_primal)]
jvp_rules[less_p] = less_jvp 

最后,我们添加一个转换 API 来启动跟踪:

def jvp_v1(f, primals, tangents):
  with new_main(JVPTrace) as main:
    trace = JVPTrace(main)
    tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
    out = f(*tracers_in)
    tracer_out = full_raise(trace, out)
    primal_out, tangent_out = tracer_out.primal, tracer_out.tangent
  return primal_out, tangent_out 

而有着,我们可以进行区分!

x = 3.0
y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,))
print(sin_deriv_at_3)
print(cos(3.0)) 
-0.9899924966004454
-0.9899924966004454 
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z
x, xdot = 3., 1.
y, ydot = jvp_v1(f, (x,), (xdot,))
print(y)
print(ydot) 
2.7177599838802657
2.979984993200891 
def deriv(f):
  return lambda x: jvp_v1(f, (x,), (1.,))[1]
print(deriv(sin)(3.))
print(deriv(deriv(sin))(3.))
print(deriv(deriv(deriv(sin)))(3.))
print(deriv(deriv(deriv(deriv(sin))))(3.)) 
-0.9899924966004454
-0.1411200080598672
0.9899924966004454
0.1411200080598672 
def f(x):
  if x > 0.:  # Python control flow
    return 2. * x
  else:
    return x
print(deriv(f)(3.))
print(deriv(f)(-3.)) 
2.0
1.0 

Pytrees 和展平用户函数的输入和输出

jvp_v1  的一个限制是它假设用户函数接受数组作为位置参数并生成单个数组作为输出。如果它生成一个列表作为输出怎么办?或者接受嵌套容器作为输入?在每一层处理堆栈时处理所有可能的容器将会很麻烦。相反,我们可以包装用户函数,使得包装版本接受数组作为输入并返回一个扁平的数组列表作为输出。包装器只需展开其输入,调用用户函数,并展平输出。

下面是我们希望编写 jvp 的方式,假设用户总是给我们采用数组作为输入并生成扁平数组列表作为输出的函数:

def jvp_flat(f, primals, tangents):
  with new_main(JVPTrace) as main:
    trace = JVPTrace(main)
    tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out)
  return primals_out, tangents_out 

为了支持具有任意容器输入和输出的用户函数,下面是我们如何编写用户界面的 jvp 包装器:

def jvp(f, primals, tangents):
  primals_flat, in_tree = tree_flatten(primals)
  tangents_flat, in_tree2 = tree_flatten(tangents)
  if in_tree != in_tree2: raise TypeError
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, tangents_out_flat = jvp_flat(f, primals_flat, tangents_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)
  tangents_out = tree_unflatten(out_tree(), tangents_out_flat)
  return primals_out, tangents_out 

注意,我们必须将用户函数输出的树结构信息传递回 flatten_fun 的调用者。这些信息在我们实际运行用户函数之前是不可用的,因此 flatten_fun 只返回一个可变单元的引用,表示为一个惰性求值体。这些副作用是安全的,因为我们总是精确地运行用户函数一次。(这种安全的制度是 linear_util.py 中“linear”名称的原因,以 线性类型 的意义上)

唯一剩下的是编写 tree_flattentree_unflattenflatten_fun

显示代码单元源代码 隐藏代码单元源代码

def flatten_fun(f, in_tree):
  store = Store()
  def flat_fun(*args_flat):
    pytree_args = tree_unflatten(in_tree, args_flat)
    out = f(*pytree_args)
    out_flat, out_tree = tree_flatten(out)
    store.set_value(out_tree)
    return out_flat
  return flat_fun, store
class Empty: pass
empty = Empty()
class Store:
  val = empty
  def set_value(self, val):
    assert self.val is empty
    self.val = val
  def __call__(self):
    return self.val 
```</details> <details class="hide above-input"><summary aria-label="Toggle hidden content">显示代码单元源代码 隐藏代码单元源代码</summary>
```py
from collections.abc import Hashable, Iterable, Iterator
import itertools as it
from typing import Callable
class NodeType(NamedTuple):
  name: str
  to_iterable: Callable
  from_iterable: Callable
def register_pytree_node(ty: type, to_iter: Callable, from_iter: Callable
                         ) -> None:
  node_types[ty] = NodeType(str(ty), to_iter, from_iter)
node_types: dict[type, NodeType] = {}
register_pytree_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs))
register_pytree_node(list,  lambda l: (None, l), lambda _, xs:  list(xs))
register_pytree_node(dict,
                     lambda d: map(tuple, unzip2(sorted(d.items()))),
                     lambda keys, vals: dict(zip(keys, vals)))
class PyTreeDef(NamedTuple):
  node_type: NodeType
  node_metadata: Hashable
  child_treedefs: tuple['PyTreeDef', ...]
class Leaf: pass
leaf = Leaf()
def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]:
  children_iter, treedef = _tree_flatten(x)
  return list(children_iter), treedef
def _tree_flatten(x: Any) -> tuple[Iterable, PyTreeDef]:
  node_type = node_types.get(type(x))
  if node_type:
    node_metadata, children = node_type.to_iterable(x)
    children_flat, child_trees = unzip2(map(_tree_flatten, children))
    flattened = it.chain.from_iterable(children_flat)
    return flattened, PyTreeDef(node_type, node_metadata, tuple(child_trees))
  else:
    return [x], leaf
def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any:
  return _tree_unflatten(treedef, iter(xs))
def _tree_unflatten(treedef: PyTreeDef, xs: Iterator) -> Any:
  if treedef is leaf:
    return next(xs)
  else:
    children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs)
    return treedef.node_type.from_iterable(treedef.node_metadata, children) 
```</details>
通过这个处理 `jvp` 的 pytree 实现,我们现在可以处理任意输入和输出容器。这将在将来的转换中非常有用!
```py
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return {'hi': z, 'there': [x, y]}
x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot) 
{'hi': np.float64(2.7177599838802657), 'there': [3.0, np.float64(0.2822400161197344)]}
{'hi': np.float64(2.979984993200891), 'there': [1.0, np.float64(-1.9799849932008908)]} 

使用 vmap 进行向量化批处理

首先是一对辅助函数,一个用于从未映射的抽象值生成映射的抽象值(通过移除一个轴),另一个用于在批处理维度之间移动:

def mapped_aval(batch_dim, aval):
  shape = list(aval.shape)
  del shape[batch_dim]
  return ShapedArray(tuple(shape), aval.dtype)
def move_batch_axis(axis_size, src, dst, x):
  if src is not_mapped:
    target_shape = list(np.shape(x))
    target_shape.insert(dst, axis_size)
    return broadcast(x, target_shape, [dst])
  elif src == dst:
    return x
  else:
    return moveaxis(x, src, dst)
def moveaxis(x, src: int, dst: int):
  perm = [i for i in range(np.ndim(x)) if i != src]
  perm.insert(dst, src)
  return transpose(x, perm) 

用于向量化批处理的 Tracer 携带一个批处理值和一个可选整数,指示批处理轴(如果有的话)。

from typing import Union
class NotMapped: pass
not_mapped = NotMapped()
BatchAxis = Union[NotMapped, int]
class BatchTracer(Tracer):
  def __init__(self, trace, val, batch_dim: BatchAxis):
    self._trace = trace
    self.val = val
    self.batch_dim = batch_dim
  @property
  def aval(self):
    if self.batch_dim is not_mapped:
      return get_aval(self.val)
    else:
      return mapped_aval(self.batch_dim, get_aval(self.val))
  def full_lower(self):
    if self.batch_dim is not_mapped:
      return full_lower(self.val)
    else:
      return self
class BatchTrace(Trace):
  pure = lift = lambda self, val: BatchTracer(self, val, not_mapped)
  def process_primitive(self, primitive, tracers, params):
    vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers)
    vmap_rule = vmap_rules[primitive]
    val_outs, bdim_outs = vmap_rule(self.axis_size, vals_in, bdims_in, **params)
    return [BatchTracer(self, x, bd) for x, bd in zip(val_outs, bdim_outs)]
  @property
  def axis_size(self):
    return self.main.global_data
vmap_rules = {} 

在这里,我们实现了可选的 Tracer.full_lower 方法,这让我们能够在不需要的情况下去除批处理跟踪器,因为它不代表批处理值。

对于 BatchTrace,类似于 JVPTracepurelift 方法只是将一个值装箱在 BatchTracer 中,并且只提供最少的上下文,这种情况下是一个采用 not_mapped 作为标志值的 batch_dim。请注意,我们使用 MainTrace 的解释器全局数据字段来存储批处理轴的大小。

接下来,我们可以为每个原语定义批处理解释器规则:

from functools import partial
def binop_batching_rule(op, axis_size, vals_in, dims_in):
  (x, y), (x_bdim, y_bdim) = vals_in, dims_in
  if x_bdim != y_bdim:
    if x_bdim is not_mapped:
      x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
      x_bdim = y_bdim
    else:
      y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
  return [op(x, y)], [x_bdim]
vmap_rules[add_p] = partial(binop_batching_rule, add)
vmap_rules[mul_p] = partial(binop_batching_rule, mul)
def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):
  (x,), (x_bdim,) = vals_in, dims_in
  return [op(x)], [x_bdim]
vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)
vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)
vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)
def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):
  (x,), (x_bdim,) = vals_in, dims_in
  new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)
  out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)
  return [reduce_sum(x, new_axis)], [out_bdim]
vmap_rules[reduce_sum_p] = reduce_sum_batching_rule 

最后,我们添加了一个转换 API 来启动跟踪:

def vmap_flat(f, in_axes, *args):
  axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes)
                if ax is not not_mapped}
  with new_main(BatchTrace, axis_size) as main:
    trace = BatchTrace(main)
    tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x
                  for x, ax in zip(args, in_axes)]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out)
  outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out)
                     for val_out, bdim in zip(vals_out, bdims_out)]
  return outs_transposed
def vmap(f, in_axes):
  def batched_f(*args):
    args_flat, in_tree = tree_flatten(args)
    in_axes_flat, in_tree2 = tree_flatten(in_axes)
    if in_tree != in_tree2: raise TypeError
    f_flat, out_tree = flatten_fun(f, in_tree)
    outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat)
    return tree_unflatten(out_tree(), outs_flat)
  return batched_f 
def add_one_to_a_scalar(scalar):
  assert np.ndim(scalar) == 0
  return 1 + scalar
vector_in = np.arange(3.)
vector_out = vmap(add_one_to_a_scalar, (0,))(vector_in)
print(vector_in)
print(vector_out) 
[0\. 1\. 2.]
[1\. 2\. 3.] 
def jacfwd(f, x):
  pushfwd = lambda v: jvp(f, (x,), (v,))[1]
  vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2)
  return vmap(pushfwd, (0,))(vecs_in)
def f(x):
  return sin(x)
jacfwd(f, np.arange(3.)) 
array([[ 1\.        ,  0\.        , -0\.        ],
       [ 0\.        ,  0.54030231, -0\.        ],
       [ 0\.        ,  0\.        , -0.41614684]]) 

这就是关于 jvpvmap 的全部内容!


JAX 中文文档(十)(3)https://developer.aliyun.com/article/1559709

相关文章
|
3月前
|
并行计算 API 异构计算
JAX 中文文档(六)(2)
JAX 中文文档(六)
31 1
|
3月前
|
机器学习/深度学习 存储 移动开发
JAX 中文文档(八)(1)
JAX 中文文档(八)
22 1
|
3月前
|
并行计算 API C++
JAX 中文文档(九)(4)
JAX 中文文档(九)
33 1
|
3月前
|
存储 机器学习/深度学习 编译器
JAX 中文文档(九)(1)
JAX 中文文档(九)
34 0
|
3月前
|
编译器 测试技术 API
JAX 中文文档(四)(4)
JAX 中文文档(四)
27 0
|
3月前
|
存储 移动开发 Python
JAX 中文文档(八)(2)
JAX 中文文档(八)
23 0
|
3月前
|
存储 编译器 芯片
JAX 中文文档(五)(5)
JAX 中文文档(五)
26 0
|
3月前
|
机器学习/深度学习 缓存 编译器
JAX 中文文档(二)(1)
JAX 中文文档(二)
45 0
|
3月前
|
存储 机器学习/深度学习 TensorFlow
JAX 中文文档(七)(5)
JAX 中文文档(七)
23 0
|
3月前
|
C++ 索引 Python
JAX 中文文档(九)(2)
JAX 中文文档(九)
19 0