JAX 中文文档(十一)(1)https://developer.aliyun.com/article/1559779
主要问题描述
vmap 移除自定义 JVP 语义问题
vmap 移除自定义 JVP 语义问题是 vmap 与具有 custom_transforms
规则的函数微分不正确组合的问题:
# old custom_transforms api to be replaced @jax.custom_transforms def f(x): return 2. * x # f_vjp :: a -> (b, CT b --o CT a) def f_vjp(x): return f(x), lambda g: 3. * x # 3 instead of 2 jax.defvjp_all(f, f_vjp) grad(f)(1.) # 3. vmap(grad(f))(np.ones(4)) # [3., 3., 3., 3.] grad(lambda x: vmap(f)(x).sum())(np.ones(4)) # [2., 2., 2., 2.]
最后一行 grad-of-vmap 有一个意外的结果!通常情况下,应用 vmap
或任何非微分转换都会导致自定义微分规则被移除。(当定义了自定义 VJP 规则时,应用 jvp
会导致失败。)
问题存在于转换就像重写一样,而 vmap
转换有效地将函数重写为不再调用新引入的具有自定义规则的原语(因此 grad
不再生成自定义规则的结果)。更详细地说,custom_transforms
机制设置了这样的环境,使得评估 f(x)
应用函数
{ lambda ; ; a. let b = f_primitive a in [b] }
其中 f_primitive
是一个新的原语(为每个 custom_transforms
函数引入,并实际上为每次函数调用引入),与自定义 VJP 规则相关联。当我们计算 grad(f)(x)
时,微分机制遇到 f_primitive
并用自定义规则处理它。
然而,因为 f_primitive
对于 vmap
来说是 透明 的,即 vmap
在(有效地内联)定义 f_primitive
的基础上操作,所以函数 vmap(f)
有效地是
{ lambda ; ; a. let b = mul 2. a in [b] }
简而言之,vmap
重写函数以其基础原语及其转换规则表示,完全移除 f_primitive
。
更一般地说,因为 vmap(f)
的语义定义为调用 f,因此删除自定义导数规则在语义上是不一致的。也就是说,由于我们定义
vmap(f)(xs) == np.stack([f(x) for x in xs])
我们必须有
jvp(vmap(f))(xs) == jvp(lambda xs: np.stack([f(x) for x in xs]))
然而,当 f
具有自定义导数规则时,就不再具备这一特性,因为自定义导数规则只在右手版本中使用,而不在左手版本中使用。
这个问题并不局限于 vmap
;它适用于所有将函数 f
转换语义定义为调用函数 f
而不是重写其为另一个函数的转换。mask
转换也属于这一类。不同的微分变换和假设的所有一元函数变为余弦变换不属于这一类。
(类似自定义 vmap
规则的额外自定义规则之间的交互可能会变得更加复杂,这表明 custom_transforms
的问题框架过于广泛。)
Python 的灵活性问题
在 JAX 中,与 Autograd 和 PyTorch 一样但不适用于 TF1,Python 函数的微分是在执行和追踪函数时执行的。这种行为有几个原因让用户喜爱。
首先,而且最重要的是,它支持基于 pdb 的工作流程,例如用于检查数值或捕获 NaNs。 也就是说,用户可以使用标准的 Python 调试器和其他 Python 原生工具来调试他们的代码,甚至可以检查运行时值以理解示例中的数值行为,并捕获诸如 NaN 等基本的运行时错误。事实上,就在为这一设计相应的 PR 工作时,特别是在 odeint
原语上,我多次使用运行时值检查来调试问题,增强了我对这一在 Python 中的关键用户工作流程的信心。一个特别方便的技巧是,在自定义 VJP 规则中插入调试器断点,以在向后传递中的特定点进入调试器。
其次,它允许对 Python 原生控制流进行微分。 我们不确定在最终的软件成品中实际使用这种功能的频率,但当用户首次尝试 JAX 或 Autograd 时,他们通常会对这种自由感到印象深刻。我们在 JAX 和 Autograd 的 README、幻灯片演示和演示中包含它是有原因的。放弃这种能力将是从 Autograd 后退的一步。我们希望 JAX 拥有最好的自动微分能力。
然而,custom_transforms
机制并没有提供这种 Python 支持的灵活性。也就是说,因为它是根据来自用户函数和自定义微分规则的 Python 代码的 jaxpr 形成而实现的,这样的代码会导致抽象值追踪错误:
# old custom_transforms api to be replaced @jax.custom_transforms def f(x): if x > 0: return x else: return 0. def f_vjp(x): return ... jax.defvjp_all(f, f_vjp) grad(f)(1.) # Error!
解决方案思路
dougalm@ 已经通过 core.call
解决了这些问题的主要思想。也就是说,我们可以将为用户函数指定自定义 JVP 规则的任务框定为一个新的 Python 级别调用原语(不会添加到 jaxpr 语言中;详见下文)。这个新的调用原语与 core.call
类似,有一个关联的用户 Python 函数,但额外还有一个表示 JVP 规则的第二个 Python 可调用对象。让我们称这个新的调用原语为 custom_jvp_call
。
类似于 vmap
如何通过应用于要调用的函数来与 core.call
交互一样,变通地写成原语的柯里化版本,vmap
与 custom_jvp_call
交互,它们有效地穿过它并应用于底层的 Python 可调用对象。这种行为意味着我们已经解决了 vmap 移除自定义 JVP 语义的问题。
vmap(call(f)) == call(vmap(f))
对于新的原语 custom_jvp_call
,我们简单地对它涉及的两个函数应用 vmap
:
vmap(custom_jvp_call(f, f_jvp)) == custom_jvp_call(vmap(f), vmap(f_jvp))
这种行为意味着我们已经解决了 vmap-移除-custom-jvp 语义问题。
jvp
变换的交互方式如人所预期的那样:它只是调用 f_jvp
,
jvp(call(f)) == call(jvp(f)) jvp(custom_jvp_call(f, f_jvp)) == f_jvp
因为custom_jvp_call
类似于core.call
(而不是像xla.xla_call
那样),它不会提升其输入的抽象级别(因为它不延迟任何内容或将任何内容转出),这意味着我们解决了 Python 灵活性问题:用户 Python 函数没有约束(除了jvp
或vjp
所需的常规函数编程约束)。
评估和编译怎么办?这两种方式是“退出”JAX 系统的两种方式,因为在这些步骤之后不能再应用额外的转换。因此,它们的规则是微不足道的:
eval(call(f)) == eval(f) jit(call(f)) == hlo_call(jit(f)) eval(custom_jvp_call(f, f_jvp)) == eval(f) jit(custom_jvp_call(f, f_jvp)) == hlo_call(jit(f))
换言之,如果一个 JVP 规则在将custom_jvp_call(f, f_jvp)
重写为f_jvp
之前没有重写,那么当我们到达评估点eval
或用jit
转出至 XLA 时,微分永远不会被应用,因此我们只需忽略f_jvp
并且像core.call
一样行事。然而,由于下面讨论的问题,custom_jvp_call
的部分评估规则必须更加复杂,因为部分评估不仅仅用于用jit
转出至 XLA。
“初始样式”jaxpr 形成原语的唯一剩余问题与lax.scan
等有关,并且它们的转换规则也有所不同。这些原语代表了一种不同类型的“转出至 jaxpr”,与编译不同,因为我们可以在转出的 jaxpr 上执行额外的转换。也就是说,当lax.scan
形成一个 jaxpr 时,它并没有退出转换系统,因为当我们对lax.scan
应用 jvp 或 vmap 时,需要对 jaxpr 所代表的函数应用它。
另一种表述剩余问题的方式是,像lax.scan
这样的初始样式原语依赖于能够往返到一个 jaxpr 并返回到 Python 可调用对象的能力,同时保留语义。这必须意味着也要保留自定义微分规则的语义。
解决方案是使用一点动态作用域:当我们将一个初始样式原语转出至 jaxpr 时,例如在 lax_control_flow.py 中的原语,我们在全局跟踪状态上设置一个位。当该位被设置时,我们使用一个初始样式custom_jvp_call_jaxpr
原语,而不是使用最终样式的custom_jvp_call
原语,并且提前跟踪函数f
和f_jvp
到 jaxpr,以使初始样式处理更容易。custom_jvp_call_jaxpr
原语在其他方面与最终样式版本类似。
(脚注:道德上,我们在绑定custom_jvp_call_jaxpr
之前为f
和f_jvp
都形成 jaxpr,但是我们需要延迟f_jvp
的 jaxpr 形成,因为它可能调用自定义 JVP 函数,因此急速处理将导致无限递归。我们在一个 thunk 中延迟该 jaxpr 形成。)
如果我们放弃 Python 的灵活性问题,我们可以仅仅使用custom_jvp_call_jaxpr
,而不需要单独的 Python 级原语custom_jvp_call
。
API
a -> b
函数的自定义 JVP 由(a, Ta) -> (b, T b)
函数指定:
# f :: a -> b @jax.custom_jvp def f(x): return np.sin(x) # f_jvp :: (a, T a) -> (b, T b) def f_jvp(primals, tangents): x, = primals t, = tangents return f(x), np.cos(x) * t f.defjvp(f_jvp)
(有趣的自动微分说明:为了使规则适用于高阶微分,必须在 f_jvp
的主体中调用 f
;这排除了 f
内部和切线计算之间某些工作共享的类型。)
一个 a -> b
函数的自定义 VJP 是通过一个 a -> (b, c)
前向传递函数与一个 (c, CT b) -> CT a
反向传递函数指定的:
# f :: a -> b @jax.custom_vjp def f(x): return np.sin(x) # f_fwd :: a -> (b, c) def f_fwd(x): return f(x), np.cos(x) # f_bwd :: (c, CT b) -> CT a def f_bwd(cos_x, g): return (cos_x * g,) f.defvjp(f_fwd, f_bwd)
签名 a -> (b, CT b --o CT a)
更具美感,但支持它将使实现变得更复杂,可能需要妥协表达性的愿望。 Python 可调用对象之所以是不透明的(除非我们追踪它们到 jaxpr 并且迫切地执行,这会放置表达约束),在这种情况下,我们可能会返回一个具有 vmap
追踪器的可调用对象,我们需要在正向传递期间了解它们。
我们可以添加方便的包装器,例如一次为单个参数定义 JVP 规则(就像我们在原语内部做的那样)。 但因为这个提案本身已经足够复杂,我决定不使用方便的层;现在让我们保持最小的东西。
API 还有一些其他的花哨功能:
- 输入和输出类型
a
、b
和c
可以是 jaxtypes 的任意 pytrees。 - 当可以使用
inspect
模块将参数按名称(关键字参数)解析为位置时,支持这种方式。 这是对 Python 3 改进的实验性质能力以编程方式检查参数签名的一部分。 我认为这是正确的,但不完整,这是一个很好的状态。(另见 #2069。) - 可以使用
nondiff_argnums
标记参数为非可区分的,并且与jit
的static_argnums
一样,这些参数不必是 JAX 类型。 我们需要设置一种约定来传递这些参数给规则。 对于具有类型签名(d, a) -> b
的原始函数,其中d
表示不可区分的类型,JVP 规则的签名是(a, T a, d) -> T b
,VJP 规则的反向组件签名是(d, c, CT b) -> CT a
。 也就是说,在自定义 JVP 规则中,非可区分的参数在primals
和tangents
之后按顺序传递,并且在自定义 VJP 规则的反向函数中的残差之前按顺序传递。
实现注意事项
- 更新了
jax.experimental.odeint
- 由于
odeint
是一个相当复杂的自定义 VJP 规则的用户,除了只更新它以使其能够正常工作外,我还希望将其修改为新的自定义 VJP API 的规范用户,以此来测试该 API 是否良好。 - 在此过程中,我对
odeint
实现进行了其他改进:
- 删除了解开/重新解开的样板代码
- 利用
lax.scan
来消除索引更新逻辑 - 在简单的单摆基准测试中加速了 20+%。
- 对每个变换添加了自定义绑定方法,用于自定义导数调用原语
custom_jvp_call
和custom_vjp_call
。 这类似于core.call_bind
,但我们不处理 env traces:这些只是错误。 - 添加了
custom_lin
原语,它在使用自定义 VJP 规则时被分阶段转化为线性 jaxprs 以进行转置。
- 由于我们的反向模式自动微分分解为线性化、部分求值和转置,我们的自定义 VJP 规则在两个独立步骤中处理:一个在线性化期间,另一个在转置期间。
- 线性化步骤,即
custom_vjp_call
的 JVP 规则,将custom_lin
应用于切线值;custom_lin
携带用户的自定义反向传播函数,并且作为一个原语,它只有一个转置规则。 - 这一机制在#636中有更详细的描述。
- 为了防止
自定义 _vjp 和 nondiff_argnums 更新指南
原文:
jax.readthedocs.io/en/latest/jep/4008-custom-vjp-update.html
mattjj@ Oct 14 2020
本文假设您熟悉 jax.custom_vjp
,如用于 JAX 可转换 Python 函数的自定义导数规则笔记本中所述。
更新内容
在 JAX 的PR #4008之后,传递给 custom_vjp
函数的 nondiff_argnums
的参数不能是 Tracer
s(或 Tracer
的容器),这基本上意味着为了允许任意可转换的代码,nondiff_argnums
不应该用于数组值的参数。相反,nondiff_argnums
应该仅用于非数组值,如 Python 可调用对象或形状元组或字符串。
无论我们以前用 nondiff_argnums
用于数组值的地方,我们应该将它们作为常规参数传递。在 bwd
规则中,我们需要为它们生成值,但我们可以只生成 None
值来指示没有相应的梯度值。
例如,这是编写 clip_gradient
的旧方法,当 hi
和/或 lo
是来自某些 JAX 转换的 Tracer
时将无法工作。
from functools import partial import jax @partial(jax.custom_vjp, nondiff_argnums=(0, 1)) def clip_gradient(lo, hi, x): return x # identity function def clip_gradient_fwd(lo, hi, x): return x, None # no residual values to save def clip_gradient_bwd(lo, hi, _, g): return (jnp.clip(g, lo, hi),) clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
这里是新的,令人惊叹的方法,支持任意转换:
import jax @jax.custom_vjp # no nondiff_argnums! def clip_gradient(lo, hi, x): return x # identity function def clip_gradient_fwd(lo, hi, x): return x, (lo, hi) # save lo and hi values as residuals def clip_gradient_bwd(res, g): lo, hi = res return (None, None, jnp.clip(g, lo, hi)) # return None for lo and hi clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
如果您使用旧方式而不是新方式,在可能出错的任何情况下(即将 Tracer
传递给 nondiff_argnums
参数时),您将会收到一个大声的错误。
这是一个我们实际上需要使用 custom_vjp
的情况,与 nondiff_argnums
:
from functools import partial import jax @partial(jax.custom_vjp, nondiff_argnums=(0,)) def skip_app(f, x): return f(x) def skip_app_fwd(f, x): return skip_app(f, x), None def skip_app_bwd(f, _, g): return (g,) skip_app.defvjp(skip_app_fwd, skip_app_bwd)
解释
将 Tracer
s 传递到 nondiff_argnums
参数中一直是有 bug 的。虽然有些情况下工作正常,但其他情况会导致复杂和令人困惑的错误消息。
这个 bug 的本质在于 nondiff_argnums
的实现方式很像词法闭包。但是那时候,对于Tracer
s 的词法闭包并不打算与custom_jvp
/custom_vjp
一起工作。以这种方式实现 nondiff_argnums
是一个错误!
PR #4008 修复了所有与 custom_jvp
和 custom_vjp
相关的词法闭包问题。 哇哦!也就是说,现在 custom_jvp
和 custom_vjp
函数和规则可以对 Tracer
s 进行词法闭包了。对于所有非自动微分转换,一切都会顺利进行。对于自动微分转换,我们将得到一个清晰的错误消息,说明为什么我们不能针对 custom_jvp
或 custom_vjp
关闭的值进行微分:
检测到对于一个闭包值的自定义 _jvp 函数的微分。这不被支持,因为自定义 JVP 规则仅指定如何针对显式输入参数微分自定义 _jvp 函数。
尝试将闭包值传递给
custom_jvp
函数作为参数,并调整custom_jvp
规则。
通过这种方式加强和健壮custom_jvp
和custom_vjp
时,我们发现允许custom_vjp
在其nondiff_argnums
中接受Tracer
将需要大量的簿记工作:我们需要重写用户的fwd
函数以返回这些值作为残差,并重写用户的bwd
函数以接受它们作为普通残差(而不是像在nondiff_argnums
中那样接受它们作为特殊的前导参数)。这似乎可能是可管理的,直到你考虑我们如何处理任意的 pytrees!此外,这种复杂性并非必要:如果用户代码将类似数组的不可区分参数视为常规参数和残差处理,一切都已经可以正常工作。(在 #4039 之前,JAX 可能会抱怨涉及整数值输入和输出的自动微分,但在 #4039 之后,这些问题将会解决!)
与custom_vjp
不同,将custom_jvp
与nondiff_argnums
参数(即Tracer
)一起使用是很容易的。因此,这些更新只需要在custom_vjp
中进行。
全面暂存
mattjj@ Sept 25 2020
这更像是升级指南而不是设计文档。
目录
- 简而言之
- “全面暂存”是什么以及其有何用处?
- 开启全面暂存可能导致哪些问题?
- 使用
jax.numpy
进行形状计算 - 副作用
- 基于 XLA 优化的小数值差异
- 依赖于已更改的 JAX 内部 API
- 触发 XLA 编译时错误
简而言之
发生了什么?
JAX 的跟踪基础设施发生的名为“全面暂存”(google/jax#3370)在 jax==0.2.0 中启用。此更改改善了内存性能、跟踪执行时间并简化了 jax 内部,但可能导致某些现有代码出现问题。通常情况下,问题是由于有 bug 的代码引起的,因此从长远来看最好修复这些 bug,但全面暂存也可以作为临时解决方法禁用。我们乐意帮助您进行修复!
如何知道全面暂存破坏了我的代码?
判断全面暂存是否负责的最简单方法是禁用全面暂存并查看问题是否消失。请参阅下面的“开启全面暂存可能导致哪些问题?”部分。
如何暂时禁用全面暂存?
注意:这适用于 JAX 版本 0.2.0 到 0.2.11;在 JAX 版本 0.2.12 及更高版本中无法禁用全面暂存
暂时可以通过以下方式禁用全面暂存
- 将 shell 环境变量
JAX_OMNISTAGING
设置为 falsey; - 如果你的代码使用 absl 解析标志,则将布尔标志
jax_omnistaging
设置为 falsey; - 在主文件顶部附近使用此语句:
jax.config.disable_omnistaging()
如何修复全面暂存暴露的错误?
全面暂存最常见的问题远远超过了使用 jax.numpy
计算形状值或其他跟踪时间常量。请参阅下面的代码块,快速了解示例,并详细了解其他问题,请参阅“开启全面暂存可能导致哪些问题?”部分。
现在改为:
@jit def f(x): input_size = jnp.prod(x.shape) if input_size > 100: ...
请执行以下操作:
import numpy as np @jit def f(x): input_size = np.prod(x.shape) if input_size > 100: ...
现在不再将 jax.numpy
视为 numpy
的可替代品,现在最好仅在需要在加速器(如 GPU)上执行计算时才考虑使用 jax.numpy
操作。
“全面暂存”是什么以及其有何用处?
全面暂存是 JAX 核心升级的名称,旨在从逐操作的 Python 到 XLA 分阶段进行计算,并避免在 jit
、pmap
和控制流原语中进行“跟踪时间常量折叠”。因此,全面暂存通过减少跟踪过程中的碎片化和生成更少的 XLA 编译时常量(有时会显著降低)来改善 JAX 的内存性能。它还可以通过在跟踪时间消除逐操作执行来改善跟踪性能。此外,全面暂存简化了 JAX 核心内部结构,修复了许多未解决的 bug,并为重要的即将推出的功能铺平了道路。
名称“全面暂存”意味着尽可能分阶段输出所有内容。
玩具示例
像jit
和pmap
这样的 JAX 变换将计算分阶段到 XLA。也就是说,我们将它们应用于由多个原始操作组成的函数,使得这些操作不再从 Python 中逐个执行,而是作为一个端到端优化的 XLA 计算的一部分。
但确切地说哪些操作被分阶段了?在全阶段之前,JAX 仅基于数据依赖性分阶段计算。这里有一个示例函数,后面是它在全阶段更改之前分阶段的 XLA HLO 程序:
from jax import jit import jax.numpy as jnp @jit def f(x): y = jnp.add(1, 1) return x * y f(3)
ENTRY jit_f.6 { constant.2 = pred[] constant(false) parameter.1 = s32[] parameter(0) constant.3 = s32[] constant(2) multiply.4 = s32[] multiply(parameter.1, constant.3) ROOT tuple.5 = (s32[]) tuple(multiply.4) }
注意,add
操作没有被分阶段。相反,我们只看到一个乘法。
这是从这个函数生成的 HLO,在全阶段更改之后:
ENTRY jit_f.8 { constant.2 = pred[] constant(false) parameter.1 = s32[] parameter(0) constant.3 = s32[] constant(1) constant.4 = s32[] constant(1) add.5 = s32[] add(constant.3, constant.4) multiply.6 = s32[] multiply(parameter.1, add.5) ROOT tuple.7 = (s32[]) tuple(multiply.6) }
稍微不那么玩具的示例
这里是在实践中可能出现的一个不那么玩具的示例,当我们想要创建布尔掩码时:
import jax.numpy as jnp from jax import lax @jit def select_tril(x): mask = jnp.arange(x.shape[0])[:, None] > jnp.arange(x.shape[1]) return lax.select(mask, x, jnp.zeros_like(x)) # lax.select is like jnp.where x = np.arange(12).reshape((3, 4)) select_tril(x)
在全阶段之前:
ENTRY jit_select_tril.8 { constant.3 = pred[] constant(false) constant.1 = pred[3,4]{1,0} constant({...}) parameter.2 = s32[3,4]{1,0} parameter(0) constant.4 = s32[] constant(0) broadcast.5 = s32[3,4]{1,0} broadcast(constant.4), dimensions={} select.6 = s32[3,4]{1,0} select(constant.1, parameter.2, broadcast.5) ROOT tuple.7 = (s32[3,4]{1,0}) tuple(select.6) }
select
操作被分阶段了,但用于构建常量mask
的操作却没有。而不是被分阶段,构建mask
的操作在 Python 追踪时逐个操作地执行,并且 XLA 只看到一个编译时常量constant.1
,表示mask
的值。这是不幸的,因为如果我们已经分阶段了构建mask
的操作,XLA 可以将它们融合到select
中,并避免完全实现结果。因此,我们最终会浪费内存,因为一个可能很大的常量,浪费时间分派多个未融合的逐个操作的 XLA 计算,甚至可能会导致内存碎片化。
(与为jnp.zeros_like(x)
构建零数组的广播相对应的操作被分阶段,因为 JAX 对来自google/jax#1668的非常简单表达式很懒惰。在全阶段之后,我们可以去掉那个懒惰的子语言,并简化 JAX 内部。)
创建mask
的原因不被分阶段的原因是,在全阶段之前,jit
基于数据依赖性运行。也就是说,jit
仅分阶段一个函数中对参数有数据依赖性的操作。控制流基元和pmap
的行为类似。在select_tril
的情况下,用于构建常量mask
的操作与参数 x 没有数据依赖关系,因此它们不会被分阶段;只有lax.select
调用具有数据依赖性。
使用全阶段后,jit
转换函数的动态上下文中的所有jax.numpy
调用都被分阶段到 XLA。也就是说,在全阶段后,select_tril
的计算 XLA 看到的是
ENTRY jit_select_tril.16 { constant.4 = pred[] constant(false) iota.1 = s32[3]{0} iota(), iota_dimension=0 broadcast.5 = s32[3,1]{1,0} broadcast(iota.1), dimensions={0} reshape.7 = s32[3]{0} reshape(broadcast.5) broadcast.8 = s32[3,4]{1,0} broadcast(reshape.7), dimensions={0} iota.2 = s32[4]{0} iota(), iota_dimension=0 broadcast.6 = s32[1,4]{1,0} broadcast(iota.2), dimensions={1} reshape.9 = s32[4]{0} reshape(broadcast.6) broadcast.10 = s32[3,4]{1,0} broadcast(reshape.9), dimensions={1} compare.11 = pred[3,4]{1,0} compare(broadcast.8, broadcast.10), direction=GT parameter.3 = s32[3,4]{1,0} parameter(0) constant.12 = s32[] constant(0) broadcast.13 = s32[3,4]{1,0} broadcast(constant.12), dimensions={} select.14 = s32[3,4]{1,0} select(compare.11, parameter.3, broadcast.13) ROOT tuple.15 = (s32[3,4]{1,0}) tuple(select.14) }
JAX 中文文档(十一)(3)https://developer.aliyun.com/article/1559782