Forward differentiation

JAX implements forward differentiation in the form of a Jacobian-Vector Product (JVP) (see the JAX autodiff cookbook).
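As a quick reminder of the JVP API on a built-in function (a minimal sketch, assuming jax and jax.numpy are available):

import jax
import jax.numpy as jnp

# jvp returns the primal output and the output tangent.
y, y_dot = jax.jvp(jnp.sin, (0.0,), (1.0,))
# y == sin(0) == 0.0 and y_dot == cos(0) * 1.0 == 1.0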

If we attempt now to compute the jvp function, we get an error because we have not yet told JAX how to differentiate the multiply_add primitive.

# The second argument `(2., 10.)` are the argument values
# where we evaluate the Jacobian, and the third `(1., 1.)`
# are the values of the tangents for the arguments.
with expectNotImplementedError():
  api.jvp(square_add_prim, (2., 10.), (1., 1.)) 
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
Found expected exception: 
Traceback (most recent call last):
  File "/tmp/ipykernel_1319/800067577.py", line 5, in <module>
    api.jvp(square_add_prim, (2., 10.), (1., 1.))
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1901, in jvp
    return _jvp(lu.wrap_init(fun), primals, tangents, has_aux=has_aux)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1930, in _jvp
    out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
NotImplementedError: Differentiation rule for 'multiply_add' not implemented 
from jax.interpreters import ad
@trace("multiply_add_value_and_jvp")
def multiply_add_value_and_jvp(arg_values, arg_tangents):
  """Evaluates the primal output and the tangents (Jacobian-vector product).
 Given values of the arguments and perturbation of the arguments (tangents), 
 compute the output of the primitive and the perturbation of the output.
 This method must be JAX-traceable. JAX may invoke it with abstract values 
 for the arguments and tangents.
 Args:
 arg_values: a tuple of arguments
 arg_tangents: a tuple with the tangents of the arguments. The tuple has 
 the same length as the arg_values. Some of the tangents may also be the 
 special value ad.Zero to specify a zero tangent.
 Returns:
 a pair of the primal output and the tangent.
 """
  x, y, z = arg_values
  xt, yt, zt = arg_tangents
  _trace("Primal evaluation:")
  # Now we have a JAX-traceable computation of the output. 
  # Normally, we can use the ma primitive itself to compute the primal output. 
  primal_out = multiply_add_prim(x, y, z)
  _trace("Tangent evaluation:")
  # We must use a JAX-traceable way to compute the tangent. It turns out that 
  # the output tangent can be computed as (xt * y + x * yt + zt),
  # which we can implement in a JAX-traceable way using the same "multiply_add_prim" primitive.
  # We do need to deal specially with Zero. Here we just turn it into a 
  # proper tensor of 0s (of the same shape as 'x'). 
  # An alternative would be to check for Zero and perform algebraic 
  # simplification of the output tangent computation.
  def make_zero(tan):
    return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan  
  output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
  return (primal_out, output_tangent)
# Register the forward differentiation rule with JAX 
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp 
# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.) 
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, 1.0, 1.0)
        call multiply_add_impl(2.0, 1.0, 1.0)
        |<- multiply_add_impl = 3.0
      |<- multiply_add_prim = 3.0
      call multiply_add_prim(1.0, 2.0, 3.0)
        call multiply_add_impl(1.0, 2.0, 3.0)
        |<- multiply_add_impl = 5.0
      |<- multiply_add_prim = 5.0
    |<- multiply_add_value_and_jvp = (14.0, 5.0)
  |<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)> 

Explanation:

  • Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.
  • Not sure how to explain that multiply_add_prim is invoked with ConcreteValue and yet we do not call multiply_add_abstract_eval.
  • I think it would be useful to show the jaxpr here (a sketch follows below).
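As a sketch of that last point (assuming jax.make_jaxpr is available alongside the api module used throughout), the jaxpr of the JVP computation could be printed like this:

import jax

# Hypothetical inspection cell: the printed jaxpr records the three
# occurrences of the multiply_add primitive (one primal, two tangent uses).
print(jax.make_jaxpr(
    lambda primals, tangents: api.jvp(square_add_prim, primals, tangents))(
        (2., 10.), (1., 1.)))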
JIT of forward differentiation

We can apply JIT to the forward differentiation function:

assert api.jit(lambda arg_values, arg_tangents: 
                   api.jvp(square_add_prim, arg_values, arg_tangents))(
         (2., 10.), (1., 1.)) == (14., 5.) 
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
    call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>), (Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>))
      Primal evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
      Tangent evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
    |<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc6e5580>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc6f4430>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc6f4470>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc6f4230>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afd8b7190>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a58e100>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":27:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1319/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_1319/2145028508.py":1:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 36): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":27:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <lambda> at 0x7f0afc66b520, file "/tmp/ipykernel_1319/2145028508.py", line 1>, 10): loc("<lambda>"("/tmp/ipykernel_1319/2145028508.py":2:0)), (<code object <module> at 0x7f0afc66b5d0, file "/tmp/ipykernel_1319/2145028508.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_1319/2145028508.py":1:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/3197095916.py': '/tmp/ipykernel_1319/3197095916.py', '/tmp/ipykernel_1319/2145028508.py': '/tmp/ipykernel_1319/2145028508.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/3197095916.py': True, 
'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/tmp/ipykernel_1319/2145028508.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69cc10>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afc68ae70>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc6e5580>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc6f4430>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc6f4470>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc6f4230>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afd8b7190>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a58e100>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":27:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1319/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_1319/2145028508.py":1:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x56229a598430>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1319/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_1319/2145028508.py":1:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 36): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":27:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <lambda> at 0x7f0afc66b520, file "/tmp/ipykernel_1319/2145028508.py", line 1>, 10): loc("<lambda>"("/tmp/ipykernel_1319/2145028508.py":2:0)), (<code object <module> at 0x7f0afc66b5d0, file "/tmp/ipykernel_1319/2145028508.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_1319/2145028508.py":1:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 86): 
loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/3197095916.py': '/tmp/ipykernel_1319/3197095916.py', '/tmp/ipykernel_1319/2145028508.py': '/tmp/ipykernel_1319/2145028508.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/tmp/ipykernel_1319/2145028508.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69cca0>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 3))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afd8dc2b0>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc6e5580>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc6f4430>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc6f4470>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc6f4230>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afd8b7190>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a58e100>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":27:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1319/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_1319/2145028508.py":1:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x56229a598430>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1319/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_1319/2145028508.py":1:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x56229a3459c0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1319/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_1319/2145028508.py":1:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file 
"/tmp/ipykernel_1319/3197095916.py", line 4>, 36): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":27:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <lambda> at 0x7f0afc66b520, file "/tmp/ipykernel_1319/2145028508.py", line 1>, 10): loc("<lambda>"("/tmp/ipykernel_1319/2145028508.py":2:0)), (<code object <module> at 0x7f0afc66b5d0, file "/tmp/ipykernel_1319/2145028508.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_1319/2145028508.py":1:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 86): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/3197095916.py': '/tmp/ipykernel_1319/3197095916.py', '/tmp/ipykernel_1319/2145028508.py': '/tmp/ipykernel_1319/2145028508.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/tmp/ipykernel_1319/2145028508.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[])], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69cc10>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%3 = "stablehlo.add"(%2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afd8b3270>] 

Notice that first we evaluate multiply_add_value_and_jvp abstractly, which in turn evaluates abstractly both the primal and the tangent evaluation (a total of 3 invocations of the ma primitive). Then we compile the 3 occurrences of the primitive.
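To inspect what gets compiled without running it, one could use JAX's ahead-of-time lowering API (a sketch; jit(...).lower(...).as_text() is assumed to be available in the JAX version at hand):

# The StableHLO text should contain the three lowered multiply_add
# applications mentioned above.
lowered = api.jit(
    lambda arg_values, arg_tangents:
        api.jvp(square_add_prim, arg_values, arg_tangents)
).lower((2., 10.), (1., 1.))
print(lowered.as_text())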

Reverse differentiation

If we attempt now to use reverse differentiation, we see that JAX starts by using multiply_add_value_and_jvp to compute the forward differentiation for abstract values, but then runs into a NotImplementedError.

When computing reverse differentiation, JAX first obtains a trace of primitives that compute the output tangent, by abstractly evaluating the forward differentiation code multiply_add_value_and_jvp. Observe that JAX performs this abstract evaluation with concrete values for the differentiation point and abstract values for the tangents. Observe also that JAX uses the special abstract tangent value Zero for the tangent corresponding to the third argument of ma. This reflects the fact that we do not differentiate w.r.t. the second argument of square_add_prim, which flows into the third argument of multiply_add_prim.

Notice also that during the abstract evaluation of the tangent we pass the value 0.0 as the tangent for the third argument. This is due to the use of the make_zero function in the definition of multiply_add_value_and_jvp.

# This is reverse differentiation w.r.t. the first argument of square_add_prim
with expectNotImplementedError():
  api.grad(square_add_prim)(2., 10.) 
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
        call multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0, dtype=float32, weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
Found expected exception: 
Traceback (most recent call last):
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 284, in get_primitive_transpose
    return primitive_transposes[p]
KeyError: multiply_add
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/tmp/ipykernel_1319/339076514.py", line 3, in <module>
    api.grad(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 621, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented 

The above error is because there is a missing piece: JAX cannot yet use the forward differentiation code to compute reverse differentiation.

Transposition

As explained above, when computing reverse differentiation JAX obtains a trace of primitives that compute the tangent using forward differentiation. Then, JAX interprets this trace abstractly backwards, applying a transposition rule to each primitive.

To understand what is going on, consider for now a simpler example of the function "f(x, y) = x * y + y". Assume we need to differentiate at the point (2., 4.). JAX will produce the following JVP tangent calculation of ft from the tangents of the inputs xt and yt:

a = xt * 4.
b = 2. * yt
c = a + b
ft = c + yt

Observe that, by construction, the tangent calculation is always linear in the input tangents. The only non-linear operator that may arise in the tangent calculation is multiplication, but then one of the operands is constant.
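We can sanity-check this linearity on our own primitive (a sketch reusing the definitions above):

# The output tangent scales linearly with the input tangents.
_, t1 = api.jvp(square_add_prim, (2., 10.), (1., 1.))
_, t3 = api.jvp(square_add_prim, (2., 10.), (3., 3.))
assert t3 == 3. * t1  # 15. == 3 * 5.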

JAX will produce the reverse differentiation computation by processing the JVP computation backwards. For each operation in the tangent computation, it accumulates the cotangents of the variables used by the operation, using the cotangent of the result of the operation:

# Initialize cotangents of inputs and intermediate vars
xct = yct = act = bct = cct = 0.
# Initialize cotangent of the output
fct = 1.
# Process "ft = c + yt"
cct += fct
yct += fct
# Process "c = a + b"
act += cct
bct += cct
# Process "b = 2. * yt"
yct += 2. * bct
# Process "a = xt * 4."
xct += act * 4.

One can verify that this computation produces xct = 4. and yct = 3., which are indeed the partial derivatives of the function f.
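A quick cross-check of these hand-derived cotangents (a sketch, using plain jax.grad on the same function):

import jax

f = lambda x, y: x * y + y
dx, dy = jax.grad(f, argnums=(0, 1))(2., 4.)
assert (dx, dy) == (4., 3.)  # matches xct and yct above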

JAX knows, for each primitive that may appear in a JVP calculation, how to transpose it. Conceptually, if the primitive p(x, y, z) is linear in the arguments y and z for a constant value of x, e.g., p(x, y, z) = y*cy + z*cz, then the transposition of the primitive is:

p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz) 

Notice that p_transpose takes the cotangent of the output of the primitive, along with a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined _ value, and for the other arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value None returned for the constant arguments.

In particular:

add_transpose(out_ct, _, _) = (out_ct, out_ct)
mult_transpose(out_ct, x, _) = (None, x * out_ct)
mult_transpose(out_ct, _, y) = (out_ct * y, None)
@trace("multiply_add_transpose")
def multiply_add_transpose(ct, x, y, z):
  """Evaluates the transpose of a linear primitive.
 This method is only used when computing the backward gradient following 
 value_and_jvp, and is only needed for primitives that are used in the JVP 
 calculation for some other primitive. We need transposition for multiply_add_prim, 
 because we have used multiply_add_prim in the computation of the output_tangent in 
 multiply_add_value_and_jvp.
 In our case, multiply_add is not a linear primitive. However, it is used linearly 
 w.r.t. tangents in multiply_add_value_and_jvp:
 output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))
 Always one of the first two multiplicative arguments is a constant.
 Args:
 ct: the cotangent of the output of the primitive.
 x, y, z: values of the arguments. The arguments that are used linearly
 get an ad.UndefinedPrimal value. The other arguments get a constant
 value.
 Returns:
 a tuple with the cotangent of the inputs, with the value None
 corresponding to the constant arguments.
 """
  if not ad.is_undefined_primal(x):
    # This use of multiply_add is with a constant "x"
    assert ad.is_undefined_primal(y)
    ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))
    res = None, ct_y, ct
  else:
    # This use of multiply_add is with a constant "y"
    assert ad.is_undefined_primal(x)
    ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))
    res = ct_x, None, ct
  return res
ad.primitive_transposes[multiply_add_p] = multiply_add_transpose 

Now we can complete the run of grad:

assert api.grad(square_add_prim)(2., 10.) == 4. 
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
        call multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0, dtype=float32, weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
call multiply_add_transpose(1.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 2.0, UndefinedPrimal(ShapedArray(float32[])))
  call multiply_add_prim(1.0, 2.0, 0.0)
    call multiply_add_impl(1.0, 2.0, 0.0)
    |<- multiply_add_impl = 2.0
  |<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (2.0, None, 1.0)
call multiply_add_transpose(1.0, 2.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 0.0)
  call multiply_add_prim(2.0, 1.0, 0.0)
    call multiply_add_impl(2.0, 1.0, 0.0)
    |<- multiply_add_impl = 2.0
  |<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (None, 2.0, 1.0) 

Notice the two calls to multiply_add_transpose. They correspond to the two uses of multiply_add_prim in the computation of the output_tangent in multiply_add_value_and_jvp. The first call to the transpose corresponds to the last use of multiply_add_prim: multiply_add_prim(xt, y, ...), where y is the constant 2.0.
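To see where these two transposed applications end up, one could again print a jaxpr (a sketch, assuming jax.make_jaxpr):

import jax

# The jaxpr of the gradient contains the two multiply_add applications
# produced by the transpose rule.
print(jax.make_jaxpr(api.grad(square_add_prim))(2., 10.))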

JIT of reverse differentiation

Notice that the abstract evaluation of multiply_add_value_and_jvp uses only abstract values, while in the absence of JIT above we used ConcreteArray.

assert api.jit(api.grad(square_add_prim))(2., 10.) == 4. 
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
      Tangent evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[])))
  call multiply_add_prim(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_transpose = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_transpose = (None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc51c360>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc50ba30>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc508d30>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc50bc30>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afc69ef20>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a56fb00>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_1319/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <module> at 0x7f0afc6694d0, file "/tmp/ipykernel_1319/3085343041.py", line 1>, 18): loc("<module>"("/tmp/ipykernel_1319/3085343041.py":1:0)), (<code object run_code at 0x7f0b3686e550, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/3197095916.py': '/tmp/ipykernel_1319/3197095916.py', '/tmp/ipykernel_1319/3085343041.py': '/tmp/ipykernel_1319/3085343041.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, 
is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_1319/3085343041.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69d600>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%1 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afc6d1b30>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc51c360>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc50ba30>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc508d30>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc50bc30>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afc69ef20>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a56fb00>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_1319/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x56229a611410>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_1319/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <module> at 0x7f0afc6694d0, file "/tmp/ipykernel_1319/3085343041.py", line 1>, 18): loc("<module>"("/tmp/ipykernel_1319/3085343041.py":1:0)), (<code object run_code at 0x7f0b3686e550, file 
"/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 86): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/3197095916.py': '/tmp/ipykernel_1319/3197095916.py', '/tmp/ipykernel_1319/3085343041.py': '/tmp/ipykernel_1319/3085343041.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_1319/3085343041.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69d930>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%4 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(%5 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afd8e6f70>] 

Batching

The batching transformation takes a point-wise computation and turns it into a computation on vectors. If we try it right now, we get a NotImplementedError:

# The arguments are two vectors instead of two scalars
with expectNotImplementedError():
  api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
                                                   np.array([10., 20.]))
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
  call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
Found expected exception: 
Traceback (most recent call last):
  File "/tmp/ipykernel_1319/2641678767.py", line 3, in <module>
    api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1214, in vmap_f
    out_flat = batching.batch(
NotImplementedError: Batching rule for 'multiply_add' not implemented 

We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, multiply_add_prim already operates pointwise on input tensors of any dimension, so the batched version can use the same multiply_add_prim implementation. (A more general batching rule is sketched after the example below.)

from jax.interpreters import batching
@trace("multiply_add_batch")
def multiply_add_batch(vector_arg_values, batch_axes):
  """Computes the batched version of the primitive.
 This must be a JAX-traceable function.
 Since the multiply_add primitive already operates pointwise on arbitrary
 dimension tensors, to batch it we can use the primitive itself. This works as
 long as both the inputs have the same dimensions and are batched along the
 same axes. The result is batched along the axis that the inputs are batched.
 Args:
 vector_arg_values: a tuple of two arguments, each being a tensor of matching
 shape.
 batch_axes: the axes that are being batched. See vmap documentation.
 Returns:
 a tuple of the result, and the result axis that was batched. 
 """
  assert batch_axes[0] == batch_axes[1]
  assert batch_axes[0] == batch_axes[2]
  _trace("Using multiply_add to compute the batch:")
  res = multiply_add_prim(*vector_arg_values)
  return res, batch_axes[0]
batching.primitive_batchers[multiply_add_p] = multiply_add_batch 
assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(
  np.array([2., 3.]),
  np.array([10., 20.])),
  [14., 29.]) 
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
  call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
    call multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))
      Using multiply_add to compute the batch:
      call multiply_add_prim([2. 3.], [2. 3.], [10. 20.])
        call multiply_add_impl([2. 3.], [2. 3.], [10. 20.])
        |<- multiply_add_impl = [14. 29.]
      |<- multiply_add_prim = [14. 29.]
    |<- multiply_add_batch = ([14. 29.], 0)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])> 
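The rule above assumes that all inputs are batched, and along the same axis. A more general rule would move every batch axis to the front and broadcast unbatched arguments; a hedged sketch (multiply_add_batch_general is a hypothetical name, not part of the example above):

import jax.numpy as jnp

def multiply_add_batch_general(vector_arg_values, batch_axes):
  # Find the batch size from any argument that is actually batched.
  size = next(a.shape[ax]
              for a, ax in zip(vector_arg_values, batch_axes)
              if ax is not None)
  # Move each batched axis to the front; broadcast unbatched arguments.
  args = [jnp.moveaxis(a, ax, 0) if ax is not None
          else jnp.broadcast_to(a, (size,) + jnp.shape(a))
          for a, ax in zip(vector_arg_values, batch_axes)]
  return multiply_add_prim(*args), 0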
JIT of batching
assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))
                    (np.array([2., 3.]),
                     np.array([10., 20.])),
                    [14., 29.]) 
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
  call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
    call multiply_add_batch((Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>), (0, 0, 0))
      Using multiply_add to compute the batch:
      call multiply_add_prim(Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[2])
      |<- multiply_add_prim = Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>
    |<- multiply_add_batch = (Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, 0)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc51cd10>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc68bdb0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc68aaf0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc689eb0>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afd8b7190>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a884960>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_batch"("/tmp/ipykernel_1319/184469370.py":25:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_1319/1392464762.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object multiply_add_batch at 0x7f0afc6687c0, file "/tmp/ipykernel_1319/184469370.py", line 4>, 52): loc("multiply_add_batch"("/tmp/ipykernel_1319/184469370.py":25:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <module> at 0x7f0afc668a80, file "/tmp/ipykernel_1319/1392464762.py", line 1>, 48): loc("<module>"("/tmp/ipykernel_1319/1392464762.py":1:0)), (<code object run_code at 0x7f0b3686e550, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/184469370.py': '/tmp/ipykernel_1319/184469370.py', '/tmp/ipykernel_1319/1392464762.py': '/tmp/ipykernel_1319/1392464762.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, 
is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/184469370.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/batching.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_1319/1392464762.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='vmap'))), primitive=multiply_add, avals_in=[ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2])], avals_out=[ShapedArray(float32[2])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69e860>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afd8920f0>] 
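
As a quick sanity check (not in the original notebook), we can also print the jaxpr of the vmapped function; with the batching rule registered, tracing succeeds and the batched program applies multiply_add once to float32[2] operands:

import jax  # jax.make_jaxpr is the public API for inspecting traced programs

# Not part of the original notebook: inspect the jaxpr produced by vmap.
# It should show a single 'multiply_add' equation on f32[2] operands.
print(jax.make_jaxpr(api.vmap(square_add_prim, in_axes=0, out_axes=0))(
    np.array([2., 3.]), np.array([10., 20.])))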


JAX 中文文档(九)(4)https://developer.aliyun.com/article/1559676
