JAX 中文文档(九)(2)

简介: JAX 中文文档(九)

JAX 中文文档(九)(1)https://developer.aliyun.com/article/1559673


JAX 基元的工作方式

原文:jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html

[外链图片转存中…(img-2VBOTe3m-1718950445124)]

necula@google.com,2019 年 10 月。

JAX 实现了 Python 函数的某些转换,例如 jitgradvmappmap。要转换的  Python 函数必须是 JAX 可追踪的,这意味着当 Python 函数执行时,它对数据应用的唯一操作是检查数据属性(例如形状或类型)或称为  JAX 基元的特殊操作。特别地,JAX 可追踪的函数有时会被 JAX 用抽象参数调用。例如,JAX 抽象值的一个示例是 ShapedArray(float32[2,2]),它捕获了值的类型和形状,但不是具体数据值。JAX 基元知道如何在具体数据值和 JAX 抽象值上操作。

转换后的 JAX 函数本身必须是 JAX 可追踪的函数,以确保这些转换可以组合,例如 jit(jacfwd(grad(f)))

JAX 已经预定义了对应大多数 XLA 操作的基元,例如 add、matmul、sin、cos 和索引。JAX 还提供了以 JAX  基元为基础实现 numpy 函数的功能,这意味着使用 JAX 的 numpy 实现的 Python 程序是 JAX  可追踪的,因此可以进行变换。其他库可以通过在 JAX 基元的基础上实现它们来使其能够被 JAX 追踪。

JAX 基元的集合是可扩展的。可以定义一个新的基元,封装函数的行为,而不是在预定义的 JAX 基元的基础上重新实现函数。

本文档的目标是解释 JAX 基元必须支持的接口,以允许 JAX 执行其所有转换。

考虑我们想要为 JAX 添加支持三个参数的乘加函数,数学上定义为“multiply_add(x, y, z) = x * y + z”。该函数在三个形状相同的浮点数值张量上逐点执行操作。

使用现有的基元

定义新函数的最简单方法是使用 JAX 基元或者已经用 JAX 基元编写的其他函数,例如在 jax.lax 模块中定义的函数:

from jax import lax
from jax._src import api
def multiply_add_lax(x, y, z):
  """Implementation of multiply-add using the jax.lax primitives."""
  return lax.add(lax.mul(x, y), z)
def square_add_lax(a, b):
  """A square-add function using the newly defined multiply-add."""
  return multiply_add_lax(a, a, b)
print("square_add_lax = ", square_add_lax(2., 10.))
# Differentiate w.r.t. the first argument
print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.)) 
square_add_lax =  14.0
grad(square_add_lax) =  4.0 

为了理解 JAX 如何内部使用这些基元,我们添加了一些帮助函数来跟踪函数调用。

#@title Helper functions (execute this cell)
import functools
import traceback
_indentation = 0
def _trace(msg=None):
  """Print a message at current indentation."""
    if msg is not None:
        print("  " * _indentation + msg)
def _trace_indent(msg=None):
  """Print a message and then indent the rest."""
    global _indentation
    _trace(msg)
    _indentation = 1 + _indentation
def _trace_unindent(msg=None):
  """Unindent then print a message."""
    global _indentation
    _indentation = _indentation - 1
    _trace(msg)
def trace(name):
  """A decorator for functions to trace arguments and results."""
  def trace_func(func):  # pylint: disable=missing-docstring
    def pp(v):
  """Print certain values more succinctly"""
        vtype = str(type(v))
        if "jax._src.xla_bridge._JaxComputationBuilder" in vtype:
            return "<JaxComputationBuilder>"
        elif "jaxlib.xla_extension.XlaOp" in vtype:
            return "<XlaOp at 0x{:x}>".format(id(v))
        elif ("partial_eval.JaxprTracer" in vtype or
              "batching.BatchTracer" in vtype or
              "ad.JVPTracer" in vtype):
            return "Traced<{}>".format(v.aval)
        elif isinstance(v, tuple):
            return "({})".format(pp_values(v))
        else:
            return str(v)
    def pp_values(args):
        return ", ".join([pp(arg) for arg in args])
    @functools.wraps(func)
    def func_wrapper(*args):
      _trace_indent("call {}({})".format(name, pp_values(args)))
      res = func(*args)
      _trace_unindent("|<- {} = {}".format(name, pp(res)))
      return res
    return func_wrapper
  return trace_func
class expectNotImplementedError(object):
  """Context manager to check for NotImplementedError."""
  def __enter__(self): pass
  def __exit__(self, type, value, tb):
    global _indentation
    _indentation = 0
    if type is NotImplementedError:
      print("\nFound expected exception:")
      traceback.print_exc(limit=3)
      return True
    elif type is None:  # No exception
      assert False, "Expected NotImplementedError"
    else:
      return False 

而不是直接使用 jax.lax 基元,我们可以使用已经用这些基元编写的其他函数,例如 jax.numpy 中的函数:

import jax.numpy as jnp
import numpy as np
@trace("multiply_add_numpy")
def multiply_add_numpy(x, y, z):
    return jnp.add(jnp.multiply(x, y), z)
@trace("square_add_numpy")
def square_add_numpy(a, b):
    return multiply_add_numpy(a, a, b)
print("\nNormal evaluation:")  
print("square_add_numpy = ", square_add_numpy(2., 10.))
print("\nGradient evaluation:")
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.)) 
Normal evaluation:
call square_add_numpy(2.0, 10.0)
  call multiply_add_numpy(2.0, 2.0, 10.0)
  |<- multiply_add_numpy = 14.0
|<- square_add_numpy = 14.0
square_add_numpy =  14.0
Gradient evaluation:
call square_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  call multiply_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  |<- multiply_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
|<- square_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
grad(square_add_numpy) =  4.0 

注意,在计算 grad 的过程中,JAX 调用了 square_add_numpymultiply_add_numpy,并使用特殊的参数 ConcreteArray(...)(在此 colab 中进一步描述)。重要的是要记住,一个 JAX 可追溯的函数必须能够不仅在具体参数上运行,还能在 JAX 可能使用的特殊抽象参数上运行。

只要函数是用 JAX 原语编写的,JAX 的可追溯性属性就得到满足。

定义新的 JAX 原语

为支持乘加功能的正确方式是使用现有的 JAX 原语,如上所示。然而,为了展示 JAX 原语的工作方式,让我们假装我们想为 JAX 添加一个新的原语来实现乘加功能。

from jax import core
multiply_add_p = core.Primitive("multiply_add")  # Create the primitive
@trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
  """The JAX-traceable way to use the JAX primitive.
 Note that the traced arguments must be passed as positional arguments
 to `bind`. 
 """
  return multiply_add_p.bind(x, y, z)
@trace("square_add_prim")
def square_add_prim(a, b):
  """A square-add function implemented using the new JAX-primitive."""
  return multiply_add_prim(a, a, b) 

如果我们尝试调用新定义的函数,我们会得到一个错误,因为我们尚未告诉 JAX 关于新原语的语义。

with expectNotImplementedError():
  square_add_prim(2., 10.) 
call square_add_prim(2.0, 10.0)
  call multiply_add_prim(2.0, 2.0, 10.0)
Found expected exception: 
Traceback (most recent call last):
  File "/tmp/ipykernel_1319/2844449444.py", line 2, in <module>
    square_add_prim(2., 10.)
  File "/tmp/ipykernel_1319/1393342955.py", line 48, in func_wrapper
    res = func(*args)
  File "/tmp/ipykernel_1319/1308506715.py", line 16, in square_add_prim
    return multiply_add_prim(a, a, b)
NotImplementedError: Evaluation rule for 'multiply_add' not implemented 

原始评估规则

@trace("multiply_add_impl")
def multiply_add_impl(x, y, z):
  """Concrete implementation of the primitive.
 This function does not need to be JAX traceable.
 Args:
 x, y, z: the concrete arguments of the primitive. Will only be called with 
 concrete values.
 Returns:
 the concrete result of the primitive.
 """
  # Note that we can use the original numpy, which is not JAX traceable
  return np.add(np.multiply(x, y), z)
# Now we register the primal implementation with JAX
multiply_add_p.def_impl(multiply_add_impl) 
<function __main__.multiply_add_impl(x, y, z)> 
assert square_add_prim(2., 10.) == 14. 
call square_add_prim(2.0, 10.0)
  call multiply_add_prim(2.0, 2.0, 10.0)
    call multiply_add_impl(2.0, 2.0, 10.0)
    |<- multiply_add_impl = 14.0
  |<- multiply_add_prim = 14.0
|<- square_add_prim = 14.0 

JIT

现在如果我们尝试使用 jit,我们会得到一个 NotImplementedError

with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.) 
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
Found expected exception: 
Traceback (most recent call last):
  File "/tmp/ipykernel_1319/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 326, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
NotImplementedError: Abstract evaluation for 'multiply_add' not implemented 
抽象评估规则

为了 JIT 函数以及其他转换,JAX 首先使用只有参数的形状和类型的抽象方式进行评估。这种抽象评估有多重目的:

  • 获取计算中使用的 JAX 原语序列。这个序列将被编译。
  • 计算所有向量和操作在计算中使用的形状和类型。

例如,具有 3 个元素的向量的抽象可能是 ShapedArray(float32[3])ConcreteArray([1., 2., 3.])。在后一种情况下,JAX 使用实际的具体值包装为抽象值。

from jax import core
@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
  """Abstract evaluation of the primitive.
 This function does not need to be JAX traceable. It will be invoked with
 abstractions of the actual arguments. 
 Args:
 xs, ys, zs: abstractions of the arguments.
 Result:
 a ShapedArray for the result of the primitive.
 """
  assert xs.shape == ys.shape
  assert xs.shape == zs.shape
  return core.ShapedArray(xs.shape, xs.dtype)
# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval) 
<function __main__.multiply_add_abstract_eval(xs, ys, zs)> 

如果我们重新尝试进行 JIT 编译,我们可以看到抽象评估的过程,但是我们会遇到另一个错误,关于缺少实际的 XLA 编译规则:

with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.) 
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
Found expected exception: 
Traceback (most recent call last):
  File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1319/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 326, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu 
XLA 编译规则

JAX 编译通过将每个原语编译成 XLA 操作的图形来工作。

这是向 JAX 添加新功能的最大障碍,因为 XLA 操作的集合是有限的,并且 JAX 已经为大多数操作预定义了原语。然而,XLA 包括一个 CustomCall 操作,可以用来封装使用 C++ 定义的任意功能。

from jax._src.lib.mlir.dialects import hlo
@trace("multiply_add_lowering")
def multiply_add_lowering(ctx, xc, yc, zc):
  """The compilation to XLA of the primitive.
 Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
 the results of the function.
 Does not need to be a JAX-traceable function.
 """
  return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]
# Now we register the lowering rule with JAX
# For GPU see the [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)
# TODO: TPU?
from jax.interpreters import mlir
mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu') 
<function __main__.multiply_add_lowering(ctx, xc, yc, zc)> 

现在我们成功 JIT。请注意下面,JAX 首先抽象评估函数,触发 multiply_add_abstract_eval 函数,然后编译它遇到的一系列原语,包括 multiply_add。在这一点上,JAX 调用 multiply_add_xla_translation

assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14. 
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc664db0>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc688cf0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc688d70>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc682b30>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afd8fd060>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a45a0d0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1319/1570919344.py":1:0) at callsite("<module>"("/tmp/ipykernel_1319/1570919344.py":1:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at "_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <lambda> at 0x7f0afd8d4ea0, file "/tmp/ipykernel_1319/1570919344.py", line 1>, 6): loc("<lambda>"("/tmp/ipykernel_1319/1570919344.py":1:0)), (<code object <module> at 0x7f0afd8d6b80, file "/tmp/ipykernel_1319/1570919344.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_1319/1570919344.py":1:0)), (<code object run_code at 0x7f0b3686e550, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object run_ast_nodes at 0x7f0b3686e3f0, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3418>, 500): loc("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0)), (<code object run_cell_async at 0x7f0b3686e080, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3183>, 828): loc("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0)), (<code object _pseudo_sync_runner at 0x7f0b36740c90, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 120>, 8): loc("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/1570919344.py': '/tmp/ipykernel_1319/1570919344.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/1570919344.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afd8fe8f0>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afc6835b0>] 

下面是另一个 jit 的用法,我们只编译关于第一个参数的部分。请注意,square_add_prim 的第二个参数是具体的,这导致第三个参数 multiply_add_abstract_evalConcreteArray。我们看到 multiply_add_abstract_eval 可以与 ShapedArrayConcreteArray 一起使用。

assert api.jit(lambda x, y: square_add_prim(x, y), 
               static_argnums=1)(2., 10.) == 14. 
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc666480>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc690530>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc6905b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc690570>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afd8ffc40>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a58e100>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1319/4165789807.py":1:0) at callsite("<module>"("/tmp/ipykernel_1319/4165789807.py":1:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at "_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <lambda> at 0x7f0afd8d5b00, file "/tmp/ipykernel_1319/4165789807.py", line 1>, 6): loc("<lambda>"("/tmp/ipykernel_1319/4165789807.py":1:0)), (<code object <module> at 0x7f0b2cd8b3c0, file "/tmp/ipykernel_1319/4165789807.py", line 1>, 20): loc("<module>"("/tmp/ipykernel_1319/4165789807.py":1:0)), (<code object run_code at 0x7f0b3686e550, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object run_ast_nodes at 0x7f0b3686e3f0, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3418>, 500): loc("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0)), (<code object run_cell_async at 0x7f0b3686e080, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3183>, 828): loc("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0)), (<code object _pseudo_sync_runner at 0x7f0b36740c90, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 120>, 8): loc("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/4165789807.py': '/tmp/ipykernel_1319/4165789807.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/4165789807.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69c250>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+01> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afd8dcff0>] 


JAX 中文文档(九)(3)https://developer.aliyun.com/article/1559675

相关文章
|
2月前
|
Shell Linux 网络安全
宝塔服务器面板部署安装git通过第三方应用安装收费怎么办—bash: git: command not found解决方案-优雅草卓伊凡
宝塔服务器面板部署安装git通过第三方应用安装收费怎么办—bash: git: command not found解决方案-优雅草卓伊凡
526 3
宝塔服务器面板部署安装git通过第三方应用安装收费怎么办—bash: git: command not found解决方案-优雅草卓伊凡
|
SQL 存储 缓存
值得收藏!my.cnf配置文档详解
MySql对于开发人员来说应该都比较熟悉,不管是小白还是老码农应该都能熟练使用。但是要说到的各种参数的配置,我敢说大部分人并不是很熟悉,当我们需要优化mysql,改变某项参数的时候。还是要到处在网上查找,有点不方便。今天就把我所知道的MySql的配置文件my.cnf做一个简单的说明吧,注意,我总结的mysql是Linux环境下的。
值得收藏!my.cnf配置文档详解
|
机器学习/深度学习 算法 算法框架/工具
为什么使用C++进行机器学习开发
C++作为一种高性能语言,在某些性能要求极高或资源受限的场景下也具有非常重要的地位。C++的高效性和对底层硬件的控制能力,使其在大规模机器学习系统中发挥重要作用,尤其是当需要处理大数据或实时响应的系统时。
296 3
|
并行计算 API C++
JAX 中文文档(九)(4)
JAX 中文文档(九)
183 1
|
SQL NoSQL 数据库
深度解密 pandas 的行列转换,以及一行变多行、一列变多列
深度解密 pandas 的行列转换,以及一行变多行、一列变多列
824 0
|
机器学习/深度学习 算法 Python
深度解析机器学习中过拟合与欠拟合现象:理解模型偏差背后的原因及其解决方案,附带Python示例代码助你轻松掌握平衡技巧
【10月更文挑战第10天】机器学习模型旨在从数据中学习规律并预测新数据。训练过程中常遇过拟合和欠拟合问题。过拟合指模型在训练集上表现优异但泛化能力差,欠拟合则指模型未能充分学习数据规律,两者均影响模型效果。解决方法包括正则化、增加训练数据和特征选择等。示例代码展示了如何使用Python和Scikit-learn进行线性回归建模,并观察不同情况下的表现。
1684 3
|
机器学习/深度学习 边缘计算 PyTorch
PyTorch 与 ONNX:模型的跨平台部署策略
【8月更文第27天】深度学习模型的训练通常是在具有强大计算能力的平台上完成的,比如配备有高性能 GPU 的服务器。然而,为了将这些模型应用到实际产品中,往往需要将其部署到各种不同的设备上,包括移动设备、边缘计算设备甚至是嵌入式系统。这就需要一种能够在多种平台上运行的模型格式。ONNX(Open Neural Network Exchange)作为一种开放的标准,旨在解决模型的可移植性问题,使得开发者可以在不同的框架之间无缝迁移模型。本文将介绍如何使用 PyTorch 将训练好的模型导出为 ONNX 格式,并进一步探讨如何在不同平台上部署这些模型。
1448 2
|
机器学习/深度学习 算法 数据安全/隐私保护
基于强化学习的倒立摆平衡车控制系统simulink建模与仿真
基于强化学习的倒立摆平衡控制系统利用MATLAB 2022a实现无水印仿真。此系统通过学习策略使摆维持垂直平衡。强化学习涉及状态(如角度和速度)、动作(施力)、奖励(反馈)及策略(选择动作)。采用Q-Learning算法更新动作价值函数Q(s,a),并通过DQN处理高维状态空间,利用经验回放和固定Q-targets提高学习效率和稳定性。
304 9
|
存储 编解码 数据可视化
单细胞空间|在Seurat中对基于图像的空间数据进行分析(1)
单细胞空间|在Seurat中对基于图像的空间数据进行分析(1)