JAX Chinese Documentation (IX) (2)




How JAX primitives work

Original: jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html


necula@google.com, October 2019.

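JAX implements certain transformations of Python functions, e.g., jit, grad, vmap, or pmap. The Python functions to be transformed must be JAX-traceable, which means that as the Python function executes, the only operations it applies to the data are either inspections of data attributes such as shape or type, or special operations called JAX primitives. In particular, a JAX-traceable function is sometimes invoked by JAX with abstract arguments. An example of a JAX abstract value is ShapedArray(float32[2,2]), which captures the type and the shape of values, but not the concrete data values. JAX primitives know how to operate on both concrete data values and on JAX abstract values.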
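A minimal sketch of what "abstract arguments" means in practice, assuming a recent JAX install: when a function runs under jax.jit it receives a tracer that carries the shape and dtype (here float32[2,2]) but no concrete values. The list seen is a hypothetical helper that only records what the function was called with; the exact tracer class name depends on the JAX version.

```python
import jax
import jax.numpy as jnp

seen = []  # records (tracer class name, shape, dtype) observed during tracing


def f(x):
    # Under jit, x is a tracer: shape and dtype are known, the values are not.
    seen.append((type(x).__name__, x.shape, x.dtype))
    return jnp.sin(x) * 2.0


out = jax.jit(f)(jnp.ones((2, 2), dtype=jnp.float32))
print(seen)
print(out.shape)
```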

The JAX-transformed functions must themselves be JAX-traceable functions, to ensure that these transformations can be composed, e.g., jit(jacfwd(grad(f))).
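As a small illustration of composability, the sketch below stacks three transformations on a cubic function chosen only for this example. For f(x) = x**3, grad(f) is 3x**2 and jacfwd(grad(f)) is 6x, so at x = 3.0 the composed, compiled function should return 18.0.

```python
import jax


def f(x):
    return x ** 3


# grad(f)(x) = 3x^2; jacfwd of that = 6x; jit compiles the composed function.
d2f = jax.jit(jax.jacfwd(jax.grad(f)))
print(d2f(3.0))
```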

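There are pre-defined JAX primitives corresponding to most XLA operations, e.g., add, matmul, sin, cos, and indexing. JAX also comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs using JAX's implementation of numpy are JAX-traceable and therefore transformable. Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.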
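One way to see that jax.numpy is implemented in terms of primitives is to print the jaxpr that jax.make_jaxpr records for a small jnp function. The function g below is a hypothetical example; the printed jaxpr should mention primitives such as sin, cos, mul and reduce_sum, although its exact layout varies between JAX versions.

```python
import jax
import jax.numpy as jnp


def g(x):
    return jnp.sum(jnp.sin(x) * jnp.cos(x))


# make_jaxpr traces g with abstract inputs and records the primitives it binds.
jaxpr_text = str(jax.make_jaxpr(g)(jnp.ones(3)))
print(jaxpr_text)
```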

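The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives, one can define a new primitive that encapsulates the behavior of the function.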

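The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.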

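Suppose we want to add to JAX support for a multiply-add function with three arguments, defined mathematically as "multiply_add(x, y, z) = x * y + z". This function operates on three identically-shaped tensors of floating point values and performs the operations pointwise.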
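For reference, the intended pointwise semantics can be written directly in NumPy. multiply_add_ref is a hypothetical helper used only to pin down the expected values; it is not the primitive itself.

```python
import numpy as np


def multiply_add_ref(x, y, z):
    """Elementwise x * y + z for three arrays of the same shape (reference only)."""
    assert x.shape == y.shape == z.shape
    return x * y + z


x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
y = np.array([4.0, 5.0, 6.0], dtype=np.float32)
z = np.array([0.5, 0.5, 0.5], dtype=np.float32)
result = multiply_add_ref(x, y, z)
print(result)  # [ 4.5 10.5 18.5]
```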

Using existing primitives

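The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other functions that are themselves written using JAX primitives, e.g., those defined in the jax.lax module: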

from jax import lax
from jax._src import api
def multiply_add_lax(x, y, z):
  """Implementation of multiply-add using the jax.lax primitives."""
  return lax.add(lax.mul(x, y), z)
def square_add_lax(a, b):
  """A square-add function using the newly defined multiply-add."""
  return multiply_add_lax(a, a, b)
print("square_add_lax = ", square_add_lax(2., 10.))
# Differentiate w.r.t. the first argument
print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.)) 
square_add_lax =  14.0
grad(square_add_lax) =  4.0 
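The argnums parameter of grad also accepts a tuple, in which case a tuple of gradients is returned. For square_add(a, b) = a*a + b the partial derivatives are 2a and 1, so at (2.0, 10.0) the sketch below should yield 4.0 and 1.0. It uses the public jax.grad rather than jax._src.api.

```python
import jax
from jax import lax


def square_add_lax(a, b):
    return lax.add(lax.mul(a, a), b)


# Gradient with respect to both arguments: (d/da, d/db) = (2a, 1).
da, db = jax.grad(square_add_lax, argnums=(0, 1))(2.0, 10.0)
print(da, db)  # 4.0 1.0
```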

In order to understand how JAX is internally using the primitives, we add some helpers for tracing function calls.

#@title Helper functions (execute this cell)
import functools
import traceback

_indentation = 0

def _trace(msg=None):
  """Print a message at current indentation."""
  if msg is not None:
    print("  " * _indentation + msg)

def _trace_indent(msg=None):
  """Print a message and then indent the rest."""
  global _indentation
  _trace(msg)
  _indentation = 1 + _indentation

def _trace_unindent(msg=None):
  """Unindent then print a message."""
  global _indentation
  _indentation = _indentation - 1
  _trace(msg)

def trace(name):
  """A decorator for functions to trace arguments and results."""

  def trace_func(func):  # pylint: disable=missing-docstring
    def pp(v):
      """Print certain values more succinctly"""
      vtype = str(type(v))
      if "jax._src.xla_bridge._JaxComputationBuilder" in vtype:
        return "<JaxComputationBuilder>"
      elif "jaxlib.xla_extension.XlaOp" in vtype:
        return "<XlaOp at 0x{:x}>".format(id(v))
      elif ("partial_eval.JaxprTracer" in vtype or
            "batching.BatchTracer" in vtype or
            "ad.JVPTracer" in vtype):
        return "Traced<{}>".format(v.aval)
      elif isinstance(v, tuple):
        return "({})".format(pp_values(v))
      else:
        return str(v)

    def pp_values(args):
      return ", ".join([pp(arg) for arg in args])

    @functools.wraps(func)
    def func_wrapper(*args):
      _trace_indent("call {}({})".format(name, pp_values(args)))
      res = func(*args)
      _trace_unindent("|<- {} = {}".format(name, pp(res)))
      return res

    return func_wrapper

  return trace_func

class expectNotImplementedError(object):
  """Context manager to check for NotImplementedError."""
  def __enter__(self): pass

  def __exit__(self, type, value, tb):
    global _indentation
    _indentation = 0
    if type is NotImplementedError:
      print("\nFound expected exception:")
      traceback.print_exc(limit=3)
      return True
    elif type is None:  # No exception
      assert False, "Expected NotImplementedError"
    else:
      return False

Instead of using jax.lax primitives directly, we can use other functions that are already written in terms of those primitives, such as those in jax.numpy:

import jax.numpy as jnp
import numpy as np
@trace("multiply_add_numpy")
def multiply_add_numpy(x, y, z):
    return jnp.add(jnp.multiply(x, y), z)
@trace("square_add_numpy")
def square_add_numpy(a, b):
    return multiply_add_numpy(a, a, b)
print("\nNormal evaluation:")  
print("square_add_numpy = ", square_add_numpy(2., 10.))
print("\nGradient evaluation:")
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.)) 
Normal evaluation:
call square_add_numpy(2.0, 10.0)
  call multiply_add_numpy(2.0, 2.0, 10.0)
  |<- multiply_add_numpy = 14.0
|<- square_add_numpy = 14.0
square_add_numpy =  14.0
Gradient evaluation:
call square_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  call multiply_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  |<- multiply_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
|<- square_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
grad(square_add_numpy) =  4.0 

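Notice that in the process of computing grad, JAX invokes square_add_numpy and multiply_add_numpy with special arguments ConcreteArray(...) (described further below in this colab). It is important to remember that a JAX-traceable function must be able to operate not only on concrete arguments but also on the special abstract arguments that JAX may use to abstract the function execution.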

The JAX traceability property is satisfied as long as the function is written in terms of JAX primitives.
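To make the traceability requirement concrete, here is a sketch contrasting a function written with the original numpy (not_traceable) and one written with jax.numpy (traceable); both names are hypothetical. Both work on concrete inputs, but under jax.jit only the jnp version can run on the abstract tracer, while the numpy version raises an error (typically jax.errors.TracerArrayConversionError, though the exact class may vary by version).

```python
import numpy as np
import jax
import jax.numpy as jnp


def not_traceable(x):
    # np.sin needs a concrete array; it cannot operate on JAX's abstract tracers.
    return np.sin(x) * 2.0


def traceable(x):
    # jnp.sin is built on the JAX primitive sin, so it works on tracers too.
    return jnp.sin(x) * 2.0


print(not_traceable(1.0))             # concrete argument: works
jit_result = jax.jit(traceable)(1.0)  # works under jit
print(jit_result)

try:
    jax.jit(not_traceable)(1.0)
    failed = False
except Exception as e:
    failed = True
    print(type(e).__name__)
```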

Defining new JAX primitives

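The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, in order to demonstrate how JAX primitives work, let us pretend that we want to add a new primitive to JAX for the multiply-add functionality.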

from jax import core
multiply_add_p = core.Primitive("multiply_add")  # Create the primitive
@trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
  """The JAX-traceable way to use the JAX primitive.
 Note that the traced arguments must be passed as positional arguments
 to `bind`. 
 """
  return multiply_add_p.bind(x, y, z)
@trace("square_add_prim")
def square_add_prim(a, b):
  """A square-add function implemented using the new JAX-primitive."""
  return multiply_add_prim(a, a, b) 

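If we try to call the newly defined functions, we get an error, because we have not yet told JAX anything about the semantics of the new primitive.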

with expectNotImplementedError():
  square_add_prim(2., 10.) 
call square_add_prim(2.0, 10.0)
  call multiply_add_prim(2.0, 2.0, 10.0)
Found expected exception: 
Traceback (most recent call last):
  File "/tmp/ipykernel_1319/2844449444.py", line 2, in <module>
    square_add_prim(2., 10.)
  File "/tmp/ipykernel_1319/1393342955.py", line 48, in func_wrapper
    res = func(*args)
  File "/tmp/ipykernel_1319/1308506715.py", line 16, in square_add_prim
    return multiply_add_prim(a, a, b)
NotImplementedError: Evaluation rule for 'multiply_add' not implemented 
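The error above comes from JAX looking up a rule for the primitive that has not been registered yet. The following pure-Python toy model sketches that idea; ToyPrimitive is hypothetical and far simpler than the real jax.core.Primitive. A primitive is a name plus a table of rules, and bind dispatches to the rule for the current kind of evaluation, raising NotImplementedError when it is missing. The messages are chosen to mirror the ones shown in this document.

```python
class ToyPrimitive:
    """Hypothetical, much-simplified model of a primitive: a name plus rule slots."""

    def __init__(self, name):
        self.name = name
        self.impl = None           # concrete evaluation rule
        self.abstract_eval = None  # shape/dtype rule used while tracing

    def def_impl(self, rule):
        self.impl = rule
        return rule

    def def_abstract_eval(self, rule):
        self.abstract_eval = rule
        return rule

    def bind(self, *args, abstract=False):
        # Pick the rule for the requested kind of evaluation; fail loudly if missing.
        rule = self.abstract_eval if abstract else self.impl
        kind = "Abstract evaluation" if abstract else "Evaluation rule"
        if rule is None:
            raise NotImplementedError(f"{kind} for '{self.name}' not implemented")
        return rule(*args)


toy_p = ToyPrimitive("multiply_add")

try:
    toy_p.bind(2.0, 2.0, 10.0)
except NotImplementedError as e:
    first_error = str(e)
    print(first_error)  # Evaluation rule for 'multiply_add' not implemented

toy_p.def_impl(lambda x, y, z: x * y + z)
value = toy_p.bind(2.0, 2.0, 10.0)
print(value)  # 14.0

try:
    toy_p.bind((3,), (3,), (3,), abstract=True)  # toy "abstract values" are shapes
except NotImplementedError as e:
    second_error = str(e)
    print(second_error)  # Abstract evaluation for 'multiply_add' not implemented
```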

Primal evaluation rules

@trace("multiply_add_impl")
def multiply_add_impl(x, y, z):
  """Concrete implementation of the primitive.
 This function does not need to be JAX traceable.
 Args:
 x, y, z: the concrete arguments of the primitive. Will only be called with 
 concrete values.
 Returns:
 the concrete result of the primitive.
 """
  # Note that we can use the original numpy, which is not JAX traceable
  return np.add(np.multiply(x, y), z)
# Now we register the primal implementation with JAX
multiply_add_p.def_impl(multiply_add_impl) 
<function __main__.multiply_add_impl(x, y, z)> 
assert square_add_prim(2., 10.) == 14. 
call square_add_prim(2.0, 10.0)
  call multiply_add_prim(2.0, 2.0, 10.0)
    call multiply_add_impl(2.0, 2.0, 10.0)
    |<- multiply_add_impl = 14.0
  |<- multiply_add_prim = 14.0
|<- square_add_prim = 14.0 

JIT

If we now try to use jit, we get a NotImplementedError:

with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.) 
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
Found expected exception: 
Traceback (most recent call last):
  File "/tmp/ipykernel_1319/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 326, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
NotImplementedError: Abstract evaluation for 'multiply_add' not implemented 
Abstract evaluation rules

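In order to JIT the function, and for other transformations as well, JAX first evaluates it abstractly, using only the shapes and types of the arguments. This abstract evaluation serves multiple purposes: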

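  • Get the sequence of JAX primitives that are used in the computation. This sequence will be compiled.
  • Compute the shapes and types of all vectors and operations used in the computation.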

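For example, the abstraction of a vector with 3 elements may be ShapedArray(float32[3]), or ConcreteArray([1., 2., 3.]). In the latter case, JAX uses the actual concrete value wrapped as an abstract value.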
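Both purposes of abstract evaluation can be observed with jax.make_jaxpr, which performs this kind of abstract evaluation. The sketch below, assuming a recent JAX version, reuses the jax.lax version of multiply-add from earlier: the jaxpr's equations give the primitive sequence, and out_avals gives the shape and dtype of the result without computing any values.

```python
import jax
import jax.numpy as jnp
from jax import lax


def multiply_add_lax(x, y, z):
    return lax.add(lax.mul(x, y), z)


x = jnp.ones(3, dtype=jnp.float32)
closed = jax.make_jaxpr(multiply_add_lax)(x, x, x)
print(closed)

# The sequence of primitives that would be compiled ...
prims = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(prims)  # ['mul', 'add']

# ... and the abstract value (shape and dtype) of the result.
print(closed.out_avals)
```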

from jax import core
@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
  """Abstract evaluation of the primitive.
 This function does not need to be JAX traceable. It will be invoked with
 abstractions of the actual arguments. 
 Args:
 xs, ys, zs: abstractions of the arguments.
 Result:
 a ShapedArray for the result of the primitive.
 """
  assert xs.shape == ys.shape
  assert xs.shape == zs.shape
  return core.ShapedArray(xs.shape, xs.dtype)
# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval) 
<function __main__.multiply_add_abstract_eval(xs, ys, zs)> 
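A convenient sanity check for an abstract evaluation rule is jax.eval_shape, which returns the result shape and dtype of a function for inputs described only by jax.ShapeDtypeStruct. Applied to the jax.lax reference implementation (an assumption of this sketch, not part of the new primitive), it shows the aval that multiply_add_abstract_eval is expected to produce for float32 inputs of shape (2, 3).

```python
import jax
import jax.numpy as jnp
from jax import lax


def multiply_add_lax(x, y, z):
    return lax.add(lax.mul(x, y), z)


spec = jax.ShapeDtypeStruct((2, 3), jnp.float32)  # shape/dtype only, no data
out = jax.eval_shape(multiply_add_lax, spec, spec, spec)
print(out.shape, out.dtype)
```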

If we re-attempt to JIT, we see how the abstract evaluation proceeds, but we get another error, about a missing actual XLA compilation rule:

with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.) 
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
Found expected exception: 
Traceback (most recent call last):
  File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1319/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 326, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu 
XLA compilation rules

JAX compilation works by compiling each primitive into a graph of XLA operations.

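This is the biggest hurdle to adding new functionality to JAX, because the set of XLA operations is limited, and JAX already has pre-defined primitives for most of them. However, XLA includes a CustomCall operation that can be used to encapsulate arbitrary functionality defined using C++.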
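To see which XLA operations existing primitives compile to, one can lower a jitted function and print its intermediate representation. This sketch assumes a JAX version that lowers to StableHLO; for the jax.lax multiply-add the text should contain a stablehlo.multiply followed by a stablehlo.add, the same two operations that a hand-written lowering rule for multiply-add would emit.

```python
import jax
import jax.numpy as jnp
from jax import lax


def multiply_add_lax(x, y, z):
    return lax.add(lax.mul(x, y), z)


x = jnp.ones(3, dtype=jnp.float32)
# Ahead-of-time lowering: no execution, just the compiler input as text.
hlo_text = jax.jit(multiply_add_lax).lower(x, x, x).as_text()
print(hlo_text)
```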

from jax._src.lib.mlir.dialects import hlo
@trace("multiply_add_lowering")
def multiply_add_lowering(ctx, xc, yc, zc):
  """The compilation to XLA of the primitive.
 Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
 the results of the function.
 Does not need to be a JAX-traceable function.
 """
  return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]
# Now we register the lowering rule with JAX
# For GPU see the [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)
# TODO: TPU?
from jax.interpreters import mlir
mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu') 
<function __main__.multiply_add_lowering(ctx, xc, yc, zc)> 

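Now we succeed to JIT. Notice below that JAX first evaluates the function abstractly, which triggers the multiply_add_abstract_eval function, and then compiles the set of primitives it has encountered, including multiply_add. At that point JAX invokes multiply_add_lowering.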

assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14. 
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afd8fe8f0>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afc6835b0>] 
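For reference, the three rules registered so far can be condensed into one self-contained sketch. It assumes a recent JAX release: Primitive is imported from jax.extend.core, falling back to jax.core on older versions, and the lowering is built with mlir.lower_fun from a jax.lax reference function instead of hand-written hlo ops as above. The primitive name madd_sketch and the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.interpreters import mlir

try:  # newer JAX releases expose Primitive here
    from jax.extend.core import Primitive
except ImportError:  # older releases
    from jax.core import Primitive

madd_p = Primitive("madd_sketch")  # hypothetical primitive name


def _madd_reference(x, y, z):
    # Written with existing lax primitives, so it can serve as impl and lowering.
    return lax.add(lax.mul(x, y), z)


@madd_p.def_impl
def _madd_impl(x, y, z):
    return _madd_reference(x, y, z)


@madd_p.def_abstract_eval
def _madd_abstract_eval(xs, ys, zs):
    assert xs.shape == ys.shape == zs.shape
    return xs  # the result has the same shape and dtype as the inputs


# lower_fun builds a lowering rule by tracing a JAX-traceable function.
mlir.register_lowering(madd_p, mlir.lower_fun(_madd_reference, multiple_results=False))


def madd(x, y, z):
    return madd_p.bind(x, y, z)


print(madd(2.0, 2.0, 10.0))           # eager: uses the impl rule
print(jax.jit(madd)(2.0, 2.0, 10.0))  # jit: abstract eval + lowering
xs = jnp.arange(3.0)
print(jax.jit(madd)(xs, xs, xs))
```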

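Below is another use of jit, where we compile only with respect to the first argument (static_argnums=1). Note that the second argument of square_add_prim is therefore concrete (10.0 in the trace below). In the JAX version that produced this output, multiply_add_abstract_eval still receives a ShapedArray for the corresponding third argument, and the lowering rule receives it as a stablehlo.constant; in older JAX versions this argument arrived as a ConcreteArray. Since multiply_add_abstract_eval only looks at shape and dtype, it works with both ShapedArray and ConcreteArray.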

assert api.jit(lambda x, y: square_add_prim(x, y), 
               static_argnums=1)(2., 10.) == 14. 
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69c250>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+01> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afd8dcff0>] 
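The effect of static_argnums can be observed with a plain Python side effect that only runs while JAX traces. In the sketch below (square_add and calls are hypothetical names), changing the traced argument reuses the compiled function, while changing the static argument triggers a new trace, so two traces are expected for three calls.

```python
import jax

calls = []  # one entry per trace; Python side effects run only while tracing


def square_add(x, y):
    calls.append(y)
    return x * x + y


f = jax.jit(square_add, static_argnums=1)
r1 = f(2.0, 10.0)  # first call: traced and compiled
r2 = f(3.0, 10.0)  # same static value and input aval: cached, no retrace
r3 = f(2.0, 11.0)  # new static value: traced again
print(r1, r2, r3)  # 14.0 19.0 15.0
print(len(calls))  # 2
```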

