How JAX primitives work
Original: jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
necula@google.com, October 2019.
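JAX implements certain transformations of Python functions, e.g., `jit`, `grad`, `vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable, which means that as the Python function executes, the only operations it applies to the data are either inspections of data attributes such as shape or type, or special operations called JAX primitives. In particular, a JAX-traceable function is sometimes invoked by JAX with abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`, which captures the type and the shape of values, but not the concrete data values. JAX primitives know how to operate on both concrete data values and on JAX abstract values.

The JAX-transformed functions must themselves be JAX-traceable functions, to ensure that these transformations can be composed, e.g., `jit(jacfwd(grad(f)))`.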
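JAX has pre-defined primitives corresponding to most XLA operations, e.g., add, matmul, sin, cos, and indexing. JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs using JAX's implementation of numpy are JAX-traceable and therefore transformable. Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.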
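The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives, one can define a new primitive that encapsulates the behavior of the function.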
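The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.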
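Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically as `multiply_add(x, y, z) = x * y + z`. This function operates on three identically-shaped tensors of floating-point values and performs the operations pointwise.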
Using existing primitives
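The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other functions that are themselves written using JAX primitives, e.g., those defined in the `jax.lax` module: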
```python
from jax import lax
from jax._src import api


def multiply_add_lax(x, y, z):
    """Implementation of multiply-add using the jax.lax primitives."""
    return lax.add(lax.mul(x, y), z)


def square_add_lax(a, b):
    """A square-add function using the newly defined multiply-add."""
    return multiply_add_lax(a, a, b)


print("square_add_lax = ", square_add_lax(2., 10.))
# Differentiate w.r.t. the first argument
print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.))
```

```
square_add_lax =  14.0
grad(square_add_lax) =  4.0
```
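The printed values can be checked by hand. Writing out the definition and its partial derivatives, then applying the chain rule to `square_add`, where `a` is passed as both `x` and `y`:

```latex
\begin{aligned}
\mathrm{multiply\_add}(x, y, z) &= x\,y + z \\
\nabla_{(x,y,z)}\,\mathrm{multiply\_add} &= (y,\; x,\; 1) \\
\mathrm{square\_add}(a, b) &= \mathrm{multiply\_add}(a, a, b) = a^2 + b,
  \qquad \mathrm{square\_add}(2, 10) = 14 \\
\frac{\partial\,\mathrm{square\_add}}{\partial a}(a, b) &= (y + x)\big|_{x = y = a} = 2a,
  \qquad 2a\big|_{a = 2} = 4
\end{aligned}
```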
To understand how JAX is internally using the primitives, we add some helper functions for tracing function calls.
```python
#@title Helper functions (execute this cell)
import functools
import traceback

_indentation = 0


def _trace(msg=None):
    """Print a message at current indentation."""
    if msg is not None:
        print(" " * _indentation + msg)


def _trace_indent(msg=None):
    """Print a message and then indent the rest."""
    global _indentation
    _trace(msg)
    _indentation = 1 + _indentation


def _trace_unindent(msg=None):
    """Unindent then print a message."""
    global _indentation
    _indentation = _indentation - 1
    _trace(msg)


def trace(name):
    """A decorator for functions to trace arguments and results."""

    def trace_func(func):  # pylint: disable=missing-docstring
        def pp(v):
            """Print certain values more succinctly"""
            vtype = str(type(v))
            if "jax._src.xla_bridge._JaxComputationBuilder" in vtype:
                return "<JaxComputationBuilder>"
            elif "jaxlib.xla_extension.XlaOp" in vtype:
                return "<XlaOp at 0x{:x}>".format(id(v))
            elif ("partial_eval.JaxprTracer" in vtype or
                  "batching.BatchTracer" in vtype or
                  "ad.JVPTracer" in vtype):
                return "Traced<{}>".format(v.aval)
            elif isinstance(v, tuple):
                return "({})".format(pp_values(v))
            else:
                return str(v)

        def pp_values(args):
            return ", ".join([pp(arg) for arg in args])

        @functools.wraps(func)
        def func_wrapper(*args):
            _trace_indent("call {}({})".format(name, pp_values(args)))
            res = func(*args)
            _trace_unindent("|<- {} = {}".format(name, pp(res)))
            return res

        return func_wrapper

    return trace_func


class expectNotImplementedError(object):
    """Context manager to check for NotImplementedError."""

    def __enter__(self):
        pass

    def __exit__(self, type, value, tb):
        global _indentation
        _indentation = 0
        if type is NotImplementedError:
            print("\nFound expected exception:")
            traceback.print_exc(limit=3)
            return True
        elif type is None:  # No exception
            assert False, "Expected NotImplementedError"
        else:
            return False
```
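Instead of using `jax.lax` primitives directly, we can use other functions that are already written in terms of those primitives, such as those in `jax.numpy`: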
```python
import jax.numpy as jnp
import numpy as np


@trace("multiply_add_numpy")
def multiply_add_numpy(x, y, z):
    return jnp.add(jnp.multiply(x, y), z)


@trace("square_add_numpy")
def square_add_numpy(a, b):
    return multiply_add_numpy(a, a, b)


print("\nNormal evaluation:")
print("square_add_numpy = ", square_add_numpy(2., 10.))
print("\nGradient evaluation:")
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
```

```
Normal evaluation:
call square_add_numpy(2.0, 10.0)
 call multiply_add_numpy(2.0, 2.0, 10.0)
 |<- multiply_add_numpy = 14.0
|<- square_add_numpy = 14.0
square_add_numpy =  14.0

Gradient evaluation:
call square_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
 call multiply_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
 |<- multiply_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
|<- square_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
grad(square_add_numpy) =  4.0
```
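Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and `multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further below in this colab). It is important to remember that a JAX-traceable function must be able to operate not only on concrete arguments but also on the special abstract arguments that JAX may use to abstract the function execution.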
The JAX traceability property is satisfied as long as the function is written in terms of JAX primitives.
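To make the traceability requirement concrete, here is a toy sketch in plain Python, with no JAX involved. The hypothetical class `ToyShaped` plays the role of an abstract value such as `ShapedArray`, recording only a shape and a dtype. A function that applies only operations the abstract value supports runs on both floats and `ToyShaped` values, while a function that branches on the data fails, because an abstract value has no data to inspect. This illustrates the idea only; it is not how JAX tracers are implemented.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyShaped:
    """Hypothetical stand-in for an abstract value such as ShapedArray:
    it records the shape and dtype of an array but never its data."""
    shape: tuple
    dtype: str

    def _combine(self, other):
        # This toy only supports elementwise ops between equal abstractions.
        if not isinstance(other, ToyShaped):
            return NotImplemented
        if (self.shape, self.dtype) != (other.shape, other.dtype):
            raise TypeError(f"incompatible operands: {self} and {other}")
        return ToyShaped(self.shape, self.dtype)

    __mul__ = _combine
    __add__ = _combine


def square_add(a, b):
    # Uses only * and +, which both floats and ToyShaped support.
    return a * a + b


def not_traceable(a, b):
    # Branches on the *value* of `a`, which an abstract value does not have.
    return a * a + b if a > 0 else b


print(square_add(2.0, 10.0))  # concrete: 14.0
print(square_add(ToyShaped((3,), "float32"), ToyShaped((3,), "float32")))
try:
    not_traceable(ToyShaped((3,), "float32"), ToyShaped((3,), "float32"))
except TypeError as err:
    print("not traceable:", err)
```

The same `square_add` body serves both calls; only the kind of argument changes, which is the property JAX relies on when it re-invokes a function with abstract values.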
Defining new JAX primitives
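The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, in order to demonstrate how JAX primitives work, let us pretend that we want to add a new primitive to JAX for the multiply-add functionality.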
```python
from jax import core

multiply_add_p = core.Primitive("multiply_add")  # Create the primitive


@trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
    """The JAX-traceable way to use the JAX primitive.

    Note that the traced arguments must be passed as positional arguments
    to `bind`.
    """
    return multiply_add_p.bind(x, y, z)


@trace("square_add_prim")
def square_add_prim(a, b):
    """A square-add function implemented using the new JAX-primitive."""
    return multiply_add_prim(a, a, b)
```
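If we try to call the newly defined functions, we get an error, because we have not yet told JAX anything about the semantics of the new primitive.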
```python
with expectNotImplementedError():
    square_add_prim(2., 10.)
```

```
call square_add_prim(2.0, 10.0)
 call multiply_add_prim(2.0, 2.0, 10.0)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1319/2844449444.py", line 2, in <module>
    square_add_prim(2., 10.)
  File "/tmp/ipykernel_1319/1393342955.py", line 48, in func_wrapper
    res = func(*args)
  File "/tmp/ipykernel_1319/1308506715.py", line 16, in square_add_prim
    return multiply_add_prim(a, a, b)
NotImplementedError: Evaluation rule for 'multiply_add' not implemented
```
Primal evaluation rules
```python
@trace("multiply_add_impl")
def multiply_add_impl(x, y, z):
    """Concrete implementation of the primitive.

    This function does not need to be JAX traceable.

    Args:
      x, y, z: the concrete arguments of the primitive. Will only be called with
        concrete values.

    Returns:
      the concrete result of the primitive.
    """
    # Note that we can use the original numpy, which is not JAX traceable
    return np.add(np.multiply(x, y), z)


# Now we register the primal implementation with JAX
multiply_add_p.def_impl(multiply_add_impl)
```

```
<function __main__.multiply_add_impl(x, y, z)>
```
```python
assert square_add_prim(2., 10.) == 14.
```

```
call square_add_prim(2.0, 10.0)
 call multiply_add_prim(2.0, 2.0, 10.0)
  call multiply_add_impl(2.0, 2.0, 10.0)
  |<- multiply_add_impl = 14.0
 |<- multiply_add_prim = 14.0
|<- square_add_prim = 14.0
```
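The behavior seen so far, an error before `def_impl` and a concrete result after it, can be mimicked with a toy registry in plain Python. `ToyPrimitive` and its methods are hypothetical names chosen to echo `core.Primitive`; in real JAX, `bind` dispatches to whichever trace is currently active, which this sketch does not model at all.

```python
class ToyPrimitive:
    """Hypothetical miniature of a primitive: a name plus an evaluation rule."""

    def __init__(self, name):
        self.name = name
        self.impl = None

    def def_impl(self, impl):
        # Register the concrete evaluation rule and return it unchanged.
        self.impl = impl
        return impl

    def bind(self, *args):
        # Toy behavior: only concrete evaluation, no transformations.
        if self.impl is None:
            raise NotImplementedError(
                f"Evaluation rule for '{self.name}' not implemented")
        return self.impl(*args)


toy_multiply_add_p = ToyPrimitive("multiply_add")

try:
    toy_multiply_add_p.bind(2.0, 2.0, 10.0)
except NotImplementedError as err:
    print("before def_impl:", err)


def toy_multiply_add_impl(x, y, z):
    return x * y + z


toy_multiply_add_p.def_impl(toy_multiply_add_impl)
print("after def_impl:", toy_multiply_add_p.bind(2.0, 2.0, 10.0))  # 14.0
```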
JIT
Now if we try to use `jit`, we get a `NotImplementedError`:
```python
with expectNotImplementedError():
    api.jit(square_add_prim)(2., 10.)
```

```
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
 call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)

Found expected exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1319/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 326, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
NotImplementedError: Abstract evaluation for 'multiply_add' not implemented
```
Abstract evaluation rules
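In order to JIT the function, and for other transformations as well, JAX first evaluates it abstractly, using only the shape and type of the arguments. This abstract evaluation serves multiple purposes:
- Gets the sequence of JAX primitives that are used in the computation. This sequence will be compiled.
- Computes the shapes and types of all vectors and operations used in the computation.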
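The first purpose, collecting the sequence of primitives, can be illustrated with a toy tracer in plain Python. The hypothetical `ToyTrace` hands out placeholder variables instead of data, and the toy primitive `toy_multiply_add` records an equation rather than computing a value, so the result loosely resembles a jaxpr. Unlike real JAX tracers, these placeholders carry no shape or dtype, so the second purpose is not modeled, and every argument must be a placeholder.

```python
import itertools


class ToyTracer:
    """Hypothetical placeholder for an intermediate value; it holds no data."""

    def __init__(self, name, trace):
        self.name = name
        self.trace = trace


class ToyTrace:
    """Hypothetical recorder that collects one equation per primitive call."""

    def __init__(self):
        self.eqns = []
        self._ids = itertools.count()

    def new_var(self):
        return ToyTracer(f"v{next(self._ids)}", self)

    def emit(self, prim_name, *args):
        out = self.new_var()
        names = " ".join(a.name for a in args)
        self.eqns.append(f"{out.name} = {prim_name} {names}")
        return out


def toy_multiply_add(x, y, z):
    # Acts like a primitive: record an equation instead of computing a value.
    return x.trace.emit("multiply_add", x, y, z)


def toy_square_add(a, b):
    return toy_multiply_add(a, a, b)


def make_toy_jaxpr(fn, n_args):
    trace = ToyTrace()
    in_vars = [trace.new_var() for _ in range(n_args)]
    out = fn(*in_vars)
    return [v.name for v in in_vars], trace.eqns, out.name


invars, eqns, outvar = make_toy_jaxpr(toy_square_add, 2)
print(invars, eqns, outvar)  # ['v0', 'v1'] ['v2 = multiply_add v0 v0 v1'] v2
```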
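For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`. In the latter case, JAX uses the actual concrete value wrapped as an abstract value.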
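Because an abstract evaluation rule reads only shapes and dtypes, it behaves the same whether an argument is a shape-only abstraction or a concrete one. The following plain-Python sketch uses a hypothetical `ToyAval` that optionally carries data (standing in for `ConcreteArray`) and a shape-checking rule in the spirit of `multiply_add_abstract_eval`; raising `TypeError` on a mismatch is a choice made for this sketch.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class ToyAval:
    """Hypothetical abstract value. `value` is None for a shape-only
    abstraction and holds the data for a concrete one."""
    shape: Tuple[int, ...]
    dtype: str
    value: Optional[tuple] = None


def toy_multiply_add_abstract_eval(xs, ys, zs):
    # Reads only shape and dtype, so shaped and concrete inputs are treated alike.
    if not (xs.shape == ys.shape == zs.shape):
        raise TypeError(f"shape mismatch: {xs.shape}, {ys.shape}, {zs.shape}")
    # The result is always a shape-only abstraction.
    return ToyAval(xs.shape, xs.dtype)


shaped = ToyAval((3,), "float32")
concrete = ToyAval((3,), "float32", value=(1.0, 2.0, 3.0))
print(toy_multiply_add_abstract_eval(shaped, concrete, shaped))
# ToyAval(shape=(3,), dtype='float32', value=None)
```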
```python
from jax import core


@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
    """Abstract evaluation of the primitive.

    This function does not need to be JAX traceable. It will be invoked with
    abstractions of the actual arguments.

    Args:
      xs, ys, zs: abstractions of the arguments.

    Returns:
      a ShapedArray for the result of the primitive.
    """
    assert xs.shape == ys.shape
    assert xs.shape == zs.shape
    return core.ShapedArray(xs.shape, xs.dtype)


# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
```

```
<function __main__.multiply_add_abstract_eval(xs, ys, zs)>
```
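If we re-attempt to JIT, we see how the abstract evaluation proceeds, but we get another error, about the missing XLA compilation rule (reported as a missing MLIR translation rule):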
```python
with expectNotImplementedError():
    api.jit(square_add_prim)(2., 10.)
```

```
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
 call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
  |<- multiply_add_abstract_eval = ShapedArray(float32[])
 |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>

Found expected exception:
Traceback (most recent call last):
  File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/ipykernel_1319/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 326, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu
```
XLA Compilation rules
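JAX compilation works by compiling each primitive into a graph of XLA operations.

This is the biggest hurdle to adding new functionality to JAX, because the set of XLA operations is limited, and JAX already has pre-defined primitives for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++.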
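Structurally, a lowering rule for `multiply_add` expands one high-level primitive into two lower-level operations, a multiply feeding an add. The toy sketch here shows just that expansion in plain Python. The `%N = op ...` text it produces is an invented format for illustration, not real StableHLO or MLIR syntax, and `toy_lower` and `toy_lower_multiply_add` are hypothetical helpers.

```python
import itertools


def toy_lower_multiply_add(emit, x, y, z):
    """Hypothetical lowering rule: one primitive becomes two simpler ops.
    Returns the name of the value holding the result."""
    prod = emit("multiply", x, y)
    return emit("add", prod, z)


def toy_lower(rule, n_args):
    """Run a toy lowering rule and return the emitted op listing."""
    lines = []
    counter = itertools.count(n_args)  # fresh names start after the params

    def emit(op, *operands):
        name = f"%{next(counter)}"
        lines.append(f"{name} = {op} {', '.join(operands)}")
        return name

    params = [f"%{i}" for i in range(n_args)]
    result = rule(emit, *params)
    lines.append(f"return {result}")
    return lines


for line in toy_lower(toy_lower_multiply_add, 3):
    print(line)
# %3 = multiply %0, %1
# %4 = add %3, %2
# return %4
```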
```python
from jax._src.lib.mlir.dialects import hlo
from jax.interpreters import mlir


@trace("multiply_add_lowering")
def multiply_add_lowering(ctx, xc, yc, zc):
    """The compilation to XLA of the primitive.

    Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
    the results of the function.

    Does not need to be a JAX-traceable function.
    """
    return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]


# Now we register the lowering rule with JAX
# For GPU see the [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)
# TODO: TPU?
mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
```

```
<function __main__.multiply_add_lowering(ctx, xc, yc, zc)>
```
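Now we succeed to JIT. Notice below that JAX first evaluates the function abstractly, which triggers the `multiply_add_abstract_eval` function, and then compiles the set of primitives it has encountered, including `multiply_add`. At that point JAX invokes `multiply_add_lowering`.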
```python
assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
```

```
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
 call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
  |<- multiply_add_abstract_eval = ShapedArray(float32[])
 |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afc6835b0>]
```
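Below is another use of `jit`, where we compile only with respect to the first argument: `static_argnums=1` marks the second argument as static. Note that the second argument to `square_add_prim` is therefore concrete (it appears as `10.0` in the trace), and so is the third argument to `multiply_add_prim`. In the trace below, JAX abstracts that value to a `ShapedArray` before calling `multiply_add_abstract_eval`, and passes it to the lowering rule as a `stablehlo.constant`. Because `multiply_add_abstract_eval` reads only `shape` and `dtype`, it can be used with both `ShapedArray` and `ConcreteArray` arguments.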
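The effect of `static_argnums` can be sketched without JAX: static arguments become part of a cache key, and each distinct static value gets its own specialized version with that value baked in as a constant (compare the `stablehlo.constant` in the trace below), while the remaining arguments stay parameters. `toy_jit` is a hypothetical helper that only specializes and caches; it does not trace or compile anything, whereas real `jit` traces and compiles each specialization.

```python
def toy_jit(fn, static_argnums=()):
    """Hypothetical sketch of jit-style specialization on static arguments.

    Each distinct tuple of static values gets its own specialized callable with
    those values baked in; nothing is actually compiled here.
    """
    static = tuple(sorted(static_argnums))
    cache = {}  # static values -> specialized callable

    def wrapper(*args):
        key = tuple(args[i] for i in static)  # static values must be hashable
        if key not in cache:
            consts = dict(zip(static, key))
            n_args = len(args)

            def specialized(*dynamic):
                it = iter(dynamic)
                full = [consts[i] if i in consts else next(it)
                        for i in range(n_args)]
                return fn(*full)

            cache[key] = specialized
        dynamic = [a for i, a in enumerate(args) if i not in static]
        return cache[key](*dynamic)

    wrapper.cache = cache
    return wrapper


def square_add_py(a, b):
    return a * a + b


f = toy_jit(square_add_py, static_argnums=(1,))
print(f(2.0, 10.0), f(3.0, 10.0), f(2.0, 1.0))  # 14.0 19.0 5.0
print(len(f.cache))  # 2: one specialization per distinct value of b
```

Calling with a new value of the static argument creates a new cache entry, which is why static arguments should take few distinct values.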
```python
assert api.jit(lambda x, y: square_add_prim(x, y), static_argnums=1)(2., 10.) == 14.
```

```
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
 call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
  call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
  |<- multiply_add_abstract_eval = ShapedArray(float32[])
 |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+01> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afd8dcff0>]
```