JAX 中文文档(十)(3)

简介: JAX 中文文档(十)

JAX 中文文档(十)(2)https://developer.aliyun.com/article/1559708


第二部分:Jaxprs

下一个即将到来的转换是jit用于即时编译,以及vjp用于反向模式自动微分。(grad仅仅是vjp的一个小包装器。) 而jvpvmap只需要每个Tracer携带一点额外的上下文,对于jitvjp,我们需要更丰富的上下文:我们需要代表程序。也就是说,我们需要 jaxprs!

Jaxprs 是 JAX 的内部程序的中间表示。它们是显式类型化的、功能性的、一阶的,并且处于 ANF 形式。因为jit的目的是将计算分阶段出 Python,所以我们需要一个程序表示。对于任何我们想要分阶段的计算,我们需要能够将其表示为数据,并且在追踪 Python 函数时逐步构建它。类似地,vjp需要一种方法来表示反向模式自动微分的后向传播计算。我们为这两个需求使用相同的 jaxpr 程序表示。

(构建程序表示是最free种类的追踪转换,因此除了处理本地 Python 控制流问题外,任何转换都可以通过首先追踪到 jaxpr,然后解释 jaxpr 来实现。)

Jaxpr 数据结构

jaxpr 术语的语法大致为:

jaxpr ::=
  { lambda <binder> , ... .
    let <eqn>
        ...
    in ( <atom> , ... ) }
binder ::= <var>:<array_type>
var ::= a | b | c | ...
atom ::= <var> | <literal>
literal ::= <int32> | <int64> | <float32> | <float64>
eqn ::= <binder> , ... = <primitive> [ <params> ] <atom> , ... 

类型的语法如下:

jaxpr_type ::= [ <array_type> , ... ] -> [ <array_type> , ... ]
array_type ::= <dtype>[<shape>]
dtype ::= f32 | f64 | i32 | i64
shape ::= <int> , ... 

我们如何将这些表示为 Python 数据结构?我们重复使用 ShapedArrays 来表示类型,并且可以用几个 Python 结构来表示术语语法:

class Var:
  aval: ShapedArray
  def __init__(self, aval): self.aval = aval
class Lit:
  val: Any
  aval: ShapedArray
  def __init__(self, val):
    self.aval = aval = raise_to_shaped(get_aval(val))
    self.val = np.array(val, aval.dtype)
Atom = Union[Var, Lit]
class JaxprEqn(NamedTuple):
  primitive: Primitive
  inputs: list[Atom]
  params: dict[str, Any]
  out_binders: list[Var]
class Jaxpr(NamedTuple):
  in_binders: list[Var]
  eqns: list[JaxprEqn]
  outs: list[Atom]
  def __hash__(self): return id(self)
  __eq__ = op.is_
def raise_to_shaped(aval):
  return ShapedArray(aval.shape, aval.dtype) 

对 jaxpr 进行类型检查涉及检查是否存在未绑定的变量,变量是否仅绑定一次,以及每个方程的原始应用类型是否与输出绑定器的类型匹配。

class JaxprType(NamedTuple):
  in_types:  list[ShapedArray]
  out_types: list[ShapedArray]
  def __repr__(self):
    in_types = ', '.join(aval.str_short() for aval in self.in_types)
    out_types = ', '.join(aval.str_short() for aval in self.out_types)
    return f'({in_types}) -> ({out_types})'
def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType:
  env: set[Var] = set()
  for v in jaxpr.in_binders:
    if v in env: raise TypeError
    env.add(v)
  for eqn in jaxpr.eqns:
    in_types = [typecheck_atom(env, x) for x in eqn.inputs]
    out_types = abstract_eval_ruleseqn.primitive
    for out_binder, out_type in zip(eqn.out_binders, out_types):
      if not out_type == out_binder.aval: raise TypeError
    for out_binder in eqn.out_binders:
      if out_binder in env: raise TypeError
      env.add(out_binder)
  in_types = [v.aval for v in jaxpr.in_binders]
  out_types = [typecheck_atom(env, x) for x in jaxpr.outs]
  return JaxprType(in_types, out_types)
def typecheck_atom(env: set[Var], x: Atom) -> ShapedArray:
  if isinstance(x, Var):
    if x not in env: raise TypeError("unbound variable")
    return x.aval
  elif isinstance(x, Lit):
    return raise_to_shaped(get_aval(x.val))
  else:
    assert False 

我们可以使用一个简单的解释器将 jaxpr 表示的函数应用于参数。

def eval_jaxpr(jaxpr: Jaxpr, args: list[Any]) -> list[Any]:
  env: dict[Var, Any] = {}
  def read(x: Atom) -> Any:
    return env[x] if type(x) is Var else x.val
  def write(v: Var, val: Any) -> None:
    assert v not in env  # single-assignment
    env[v] = val
  map(write, jaxpr.in_binders, args)
  for eqn in jaxpr.eqns:
    in_vals = map(read, eqn.inputs)
    outs = bind(eqn.primitive, *in_vals, **eqn.params)
    map(write, eqn.out_binders, outs)
  return map(read, jaxpr.outs)
def jaxpr_as_fun(jaxpr: Jaxpr):
  return lambda *args: eval_jaxpr(jaxpr, args) 

通过在解释器中使用bind,这个解释器本身是可追踪的。

使用追踪构建 jaxprs

现在我们有了 jaxprs 作为一个数据结构,我们需要从追踪 Python 代码产生它们的方法。一般来说,我们追踪到 jaxpr 有两种变体;jit使用其中一种,而vjp使用另一种。我们将从jit使用的变体开始,这也被控制流原语如lax.condlax.while_looplax.scan所使用。

def split_list(lst: list[Any], n: int) -> tuple[list[Any], list[Any]]:
  assert 0 <= n <= len(lst)
  return lst[:n], lst[n:]
def partition_list(bs: list[bool], l: list[Any]) -> tuple[list[Any], list[Any]]:
  assert len(bs) == len(l)
  lists = lst1, lst2 = [], []
  for b, x in zip(bs, l):
    lists[b].append(x)
  return lst1, lst2 
# NB: the analogous class in JAX is called 'DynamicJaxprTracer'
class JaxprTracer(Tracer):
  __slots__ = ['aval']
  aval: ShapedArray
  def __init__(self, trace, aval):
    self._trace = trace
    self.aval = aval
# NB: the analogous class in JAX is called 'DynamicJaxprTrace'
class JaxprTrace(Trace):
  def new_arg(self, aval: ShapedArray) -> JaxprTracer:
    aval = raise_to_shaped(aval)
    tracer = self.builder.new_tracer(self, aval)
    self.builder.tracer_to_var[id(tracer)] = Var(aval)
    return tracer
  def get_or_make_const_tracer(self, val: Any) -> JaxprTracer:
    tracer = self.builder.const_tracers.get(id(val))
    if tracer is None:
      tracer = self.builder.new_tracer(self, raise_to_shaped(get_aval(val)))
      self.builder.add_const(tracer, val)
    return tracer
  pure = lift = get_or_make_const_tracer
  def process_primitive(self, primitive, tracers, params):
    avals_in = [t.aval for t in tracers]
    avals_out = abstract_eval_rulesprimitive
    out_tracers = [self.builder.new_tracer(self, a) for a in avals_out]
    inputs = [self.builder.getvar(t) for t in tracers]
    outvars = [self.builder.add_var(t) for t in out_tracers]
    self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvars))
    return out_tracers
  @property
  def builder(self):
    return self.main.global_data
# NB: in JAX, we instead attach abstract eval rules to Primitive instances
abstract_eval_rules = {} 

注意,我们在解释器全局数据中保持一个构建器对象,该对象跟踪变量、常量和等式,随着我们构建 jaxpr 而逐步积累。

class JaxprBuilder:
  eqns: list[JaxprEqn]
  tracer_to_var: dict[int, Var]
  const_tracers: dict[int, JaxprTracer]
  constvals: dict[Var, Any]
  tracers: list[JaxprTracer]
  def __init__(self):
    self.eqns = []
    self.tracer_to_var = {}
    self.const_tracers = {}
    self.constvals = {}
    self.tracers = []
  def new_tracer(self, trace: JaxprTrace, aval: ShapedArray) -> JaxprTracer:
    tracer = JaxprTracer(trace, aval)
    self.tracers.append(tracer)
    return tracer
  def add_eqn(self, eqn: JaxprEqn) -> None:
    self.eqns.append(eqn)
  def add_var(self, tracer: JaxprTracer) -> Var:
    assert id(tracer) not in self.tracer_to_var
    var = self.tracer_to_var[id(tracer)] = Var(tracer.aval)
    return var
  def getvar(self, tracer: JaxprTracer) -> Var:
    var = self.tracer_to_var.get(id(tracer))
    assert var is not None
    return var
  def add_const(self, tracer: JaxprTracer, val: Any) -> Var:
    var = self.add_var(tracer)
    self.const_tracers[id(val)] = tracer
    self.constvals[var] = val
    return var
  def build(self, in_tracers: list[JaxprTracer], out_tracers: list[JaxprTracer]
            ) -> tuple[Jaxpr, list[Any]]:
    constvars, constvals = unzip2(self.constvals.items())
    t2v = lambda t: self.tracer_to_var[id(t)]
    in_binders = constvars + [t2v(t) for t in in_tracers]
    out_vars = [t2v(t) for t in out_tracers]
    jaxpr = Jaxpr(in_binders, self.eqns, out_vars)
    typecheck_jaxpr(jaxpr)
    jaxpr, constvals = _inline_literals(jaxpr, constvals)
    return jaxpr, constvals 
def _inline_literals(jaxpr: Jaxpr, consts: list[Any]) -> tuple[Jaxpr, list[Any]]:
  const_binders, other_binders = split_list(jaxpr.in_binders, len(consts))
  scalars = [type(x) in jax_types and not get_aval(x).shape for x in consts]
  new_const_binders, lit_binders = partition_list(scalars, const_binders)
  new_consts, lit_vals = partition_list(scalars, consts)
  literals = dict(zip(lit_binders, map(Lit, lit_vals)))
  new_eqns = [JaxprEqn(eqn.primitive, [literals.get(x, x) for x in eqn.inputs],
                       eqn.params, eqn.out_binders) for eqn in jaxpr.eqns]
  new_outs = [literals.get(x, x) for x in jaxpr.outs]
  new_jaxpr = Jaxpr(new_const_binders + other_binders, new_eqns, new_outs)
  typecheck_jaxpr(new_jaxpr)
  return new_jaxpr, new_consts 

我们需要JaxprTrace.process_primitive的规则基本上是原始应用的类型规则:给定原始应用、其参数和输入的类型,规则必须生成一个输出类型,然后与输出的JaxprTracer一起打包。我们可以使用抽象评估规则来实现相同的目的,尽管它们可能更加通用(因为抽象评估规则必须接受  ConcreteArray 输入,并且因为它们只需返回可能输出集的上限,它们也可以生成 ConcreteArray  输出)。我们将重用这些抽象评估规则用于其他生成 jaxpr 的跟踪机制,其中额外的通用性是有用的。

def binop_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
  if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
    raise TypeError
  if raise_to_shaped(x) != raise_to_shaped(y): raise TypeError
  return [ShapedArray(x.shape, x.dtype)]
abstract_eval_rules[add_p] = binop_abstract_eval
abstract_eval_rules[mul_p] = binop_abstract_eval
def compare_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
  if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
    raise TypeError
  if x.shape != y.shape: raise TypeError
  return [ShapedArray(x.shape, np.dtype('bool'))]
abstract_eval_rules[greater_p] = compare_abstract_eval
abstract_eval_rules[less_p] = compare_abstract_eval
def vectorized_unop_abstract_eval(x: ShapedArray) -> list[ShapedArray]:
  return [ShapedArray(x.shape, x.dtype)]
abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval
def reduce_sum_abstract_eval(x: ShapedArray, *, axis: tuple[int, ...]
                             ) -> list[ShapedArray]:
  axis_ = set(axis)
  new_shape = [d for i, d in enumerate(x.shape) if i not in axis_]
  return [ShapedArray(tuple(new_shape), x.dtype)]
abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval
def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int],
                            axes: Sequence[int]) -> list[ShapedArray]:
  return [ShapedArray(tuple(shape), x.dtype)]
abstract_eval_rules[broadcast_p] = broadcast_abstract_eval 

要验证我们的 jaxprs 实现,我们可以添加一个make_jaxpr转换和一个漂亮的打印机:

from functools import lru_cache
@lru_cache()  # ShapedArrays are hashable
def make_jaxpr_v1(f, *avals_in):
  avals_in, in_tree = tree_flatten(avals_in)
  f, out_tree = flatten_fun(f, in_tree)
  builder = JaxprBuilder()
  with new_main(JaxprTrace, builder) as main:
    trace = JaxprTrace(main)
    tracers_in = [trace.new_arg(aval) for aval in avals_in]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    jaxpr, consts = builder.build(tracers_in, tracers_out)
  return jaxpr, consts, out_tree() 

显示代码单元格源代码 隐藏代码单元格源代码

from collections import defaultdict
import string
class PPrint:
  lines: list[tuple[int, str]]
  def __init__(self, lines):
    self.lines = lines
  def indent(self, indent: int) -> 'PPrint':
    return PPrint([(indent + orig_indent, s) for orig_indent, s in self.lines])
  def __add__(self, rhs: 'PPrint') -> 'PPrint':
    return PPrint(self.lines + rhs.lines)
  def __rshift__(self, rhs: 'PPrint') -> 'PPrint':
    if not rhs.lines: return self
    if not self.lines: return rhs
    indent, s = self.lines[-1]
    indented_block = rhs.indent(indent + len(s))
    common_line = s + ' ' * rhs.lines[0][0] + rhs.lines[0][1]
    return PPrint(self.lines[:-1]
                  + [(indent, common_line)]
                  + indented_block.lines[1:])
  def __str__(self) -> str:
    return '\n'.join(' ' * indent + s for indent, s in self.lines)
def pp(s: Any) -> PPrint:
  return PPrint([(0, line) for line in str(s).splitlines()])
def vcat(ps: list[PPrint]) -> PPrint:
  return sum(ps, pp(''))
def pp_jaxpr(jaxpr: Jaxpr) -> PPrint:
  namegen = (''.join(s) for r in it.count(1)
             for s in it.permutations(string.ascii_lowercase, r))
  names = defaultdict(lambda: next(namegen))
  in_binders = ', '.join(var_str(names, x) for x in jaxpr.in_binders)
  eqns = vcat([pp_eqn(names, e) for e in jaxpr.eqns])
  outs = ', '.join(names[v] if isinstance(v, Var) else str(v.val)
                   for v in jaxpr.outs)
  return (pp(f'{{ lambda {in_binders} .') +
          ((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2))
def var_str(names: defaultdict[Var, str], v: Var) -> str:
  return f'{names[v]}:{v.aval.str_short()}'
def pp_eqn(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
  rule = pp_rules.get(eqn.primitive)
  if rule:
    return rule(names, eqn)
  else:
    lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
    rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>
           pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
                       for x in eqn.inputs)))
    return lhs >> pp(' = ') >> rhs
def pp_params(params: dict[str, Any]) -> PPrint:
  items = sorted(params.items())
  if items:
    return pp(' [ ') >> vcat([pp(f'{k}={v}') for k, v in items]) >> pp(' ] ')
  else:
    return pp(' ')
Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))
pp_rules: dict[Primitive, Callable[..., PPrint]] = {} 
```</details>
```py
jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.)))
print(jaxpr)
print(typecheck_jaxpr(jaxpr)) 
{ lambda a:float64[] .
  let b:float64[] = mul 2.0 a
  in ( b ) }
(float64[]) -> (float64[]) 

但是这里有一个限制:由于find_top_trace是通过数据依赖操作的,make_jaxpr_v1无法将其给定的 Python 可调用对象执行的所有原始操作分阶段处理出来。例如:

jaxpr, consts, _ = make_jaxpr_v1(lambda: mul(2., 2.))
print(jaxpr) 
{ lambda  .
  let 
  in ( 4.0 ) } 

这正是omnistaging修复的问题。我们希望确保make_jaxpr启动的JaxprTrace始终被应用,而不管bind的任何输入是否被装箱在相应的JaxprTracer实例中。我们可以通过使用第一部分定义的dynamic_trace全局变量来实现这一点:

@contextmanager
def new_dynamic(main: MainTrace):
  global dynamic_trace
  prev_dynamic_trace, dynamic_trace = dynamic_trace, main
  try:
    yield
  finally:
    dynamic_trace = prev_dynamic_trace
@lru_cache()
def make_jaxpr(f: Callable, *avals_in: ShapedArray,
               ) -> tuple[Jaxpr, list[Any], PyTreeDef]:
  avals_in, in_tree = tree_flatten(avals_in)
  f, out_tree = flatten_fun(f, in_tree)
  builder = JaxprBuilder()
  with new_main(JaxprTrace, builder) as main:
    with new_dynamic(main):
      trace = JaxprTrace(main)
      tracers_in = [trace.new_arg(aval) for aval in avals_in]
      outs = f(*tracers_in)
      tracers_out = [full_raise(trace, out) for out in outs]
      jaxpr, consts = builder.build(tracers_in, tracers_out)
  return jaxpr, consts, out_tree()
jaxpr, consts, _ = make_jaxpr(lambda: mul(2., 2.))
print(jaxpr) 
{ lambda  .
  let a:float64[] = mul 2.0 2.0
  in ( a ) } 

以这种方式使用dynamic_trace在概念上与将当前解释器堆栈存储并使用JaxprTrace作为底部开始新的解释器堆栈是相同的。也就是说,比JaxprTrace.process_primitive低的堆栈解释器不会被应用(因为它不调用bind),尽管如果被跟踪到 jaxpr 的 Python 可调用对象本身使用转换,那么这些转换可以被推送到位于JaxprTrace上面的解释器堆栈中。但是临时存储解释器堆栈会破坏系统状态。dynamic_trace标记通过保持系统状态更简单来实现相同的目标。

这就是 jaxprs 的全部内容!有了 jaxprs,我们可以实现其余的主要 JAX 特性。


JAX 中文文档(十)(4)https://developer.aliyun.com/article/1559710

相关文章
|
3月前
|
存储 安全 API
JAX 中文文档(十)(2)
JAX 中文文档(十)
33 0
|
3月前
|
测试技术 TensorFlow 算法框架/工具
JAX 中文文档(五)(2)
JAX 中文文档(五)
35 0
|
3月前
|
机器学习/深度学习 缓存 API
JAX 中文文档(一)(4)
JAX 中文文档(一)
40 0
|
3月前
|
测试技术 API Python
JAX 中文文档(八)(4)
JAX 中文文档(八)
27 0
|
3月前
|
安全 编译器 TensorFlow
JAX 中文文档(四)(5)
JAX 中文文档(四)
22 0
|
3月前
|
机器学习/深度学习
JAX 中文文档(六)(5)
JAX 中文文档(六)
24 0
|
3月前
|
并行计算 Linux 异构计算
JAX 中文文档(一)(1)
JAX 中文文档(一)
81 0
|
3月前
|
存储 缓存 API
JAX 中文文档(五)(1)
JAX 中文文档(五)
26 0
|
3月前
|
存储 缓存 测试技术
JAX 中文文档(三)(5)
JAX 中文文档(三)
38 0
|
3月前
|
存储 缓存 索引
JAX 中文文档(五)(3)
JAX 中文文档(五)
44 0