JAX 中文文档(十二)(3)https://developer.aliyun.com/article/1559716
jax.extend:一个用于扩展的模块
@froystig, @sharadmv, @jakevdp, @yashk2810
2023 年 5 月
import jax.extend as jex
多个项目依赖于 JAX 的代码库内部,通常用于使用其核心机制(例如编写其 IR 上的转换)或扩展它(例如定义新的原语)。这些依赖的两个挑战是(a)我们的内部结构并不都是为外部使用而设计的,以及(b)绕过 JAX 的公共 API 是不受支持的。换句话说,我们的内部经常被用作库,但既不像库那样结构化也不像库那样更新。
此提案考虑引入一个jax.extend
模块,定义 JAX 一些内部组件的库视图。我们将其视为第二层 API,仍然基本不保证兼容性政策,但希望在发生更改时更容易发现。
jax.extend
的受众包括与 JAX 相关的 Python 库,如Oryx,jax-triton等,以及进行函数转换、自动微分系统、数值编程编译器前端等实验的项目。
本说明概述了jax.extend
现在和将来可能的样子。它没有详细列出所有细节,而是建议我们开始逐步开发这个模块。
注意,jax.extend
与jax.experimental
不同,后者是新功能和正在进行的想法的一个暂存场所。通常,jax.experimental
中的工作最终会进入另一个 JAX 模块或被完全移除。
没有兼容性政策
为了保持开发的开销低,jax.extend
不会遵循公共API 兼容性政策。它将承诺没有弃用窗口,也没有版本间的向后兼容性。每个发布都可能会破坏现有的调用者,没有简单的回退措施(例如没有重新引入先前行为的标志)。我们将依赖变更日志来指出这些更改。
调用jax.extend
的调用者可能会发现在 JAX 发布时与其代码一起定期升级对他们有用。这是当今依赖 JAX 内部的项目的一个常见习惯。不同之处在于现在它将以更好的意图和更好的库设计和命名帮助中,伴随着变更日志公告的形式出现。
逐步开发
没有兼容性政策使得在实施上更容易上手:第一天,我们可以从内部包(如jax._src
)中移植少量符号到今天的jax.core
和jax.interpreters
。然后我们可以迭代改进。
可能的模块概述
我们可以设想,最终jax.extend
可能包括以下模块:
core
– 原语,Jaxpr IR 等。interpreters
– 核心转换(例如自动微分、批处理)和降低。random
– 随机位生成、关键分割和折叠、关键数组。sharding
– 关于分布式数组的额外功能。
最初在模块中可能还有其他符号,例如jex.api_util
,随着我们的工作,我们将移除或替换它们。其他的时间会决定。例如,jex.lib
可能在短期内提供访问 jexlib 的入口点,但是目前还不清楚我们是否想长期保留它。
对每个这些内容可能包含的一些初步想法的追踪。
jax.extend.core
这应该至少使调用者能够定义新的 JAX 原语并处理 Jaxpr IR(jax.make_jaxpr(...)
的输出)。支持这一点可能涉及提供:
- 访问现有的核心系统原语,例如今天的
jax._src.lax.add_p
。 - 访问 IR 类型,例如当前的
jax._src.core.ShapedArray
。 - 用于检查和漂亮打印 jaxprs 的功能。
- 明确构建 jaxprs 的功能,而不是通过
jax.make_jaxpr
分阶段地阶段 Python 函数(或不阶段化!)。
在初始化时,这个模块将包含比定义原语和规则所需更多的符号,包括在设置“最终风格转换”时使用的各种名称,例如当前的jax._src.core.Trace
和Tracer
类。我们可以重新审视jex.core
是否应该支持初始风格方法以及是否可以通过比完全暴露Trace
和Tracer
更狭窄的 API 来支持最终风格扩展。Oryx可能会帮助指导这些决策。
我们还可以考虑将make_jaxpr
本身迁移到jax.core
中。
jax.extend.interpreters
此模块将提供注册各种原语转换规则的手段 —— 定义它们在自动微分、批处理、降低等方面的行为。
最初将反映jax._src.interpreters
,提供模块ad
、batching
、partial_eval
(用于将 Python 编程转换为 Jaxpr,并用于自动微分中的线性化)、mlir
、pxla
和xla
。前三者可能可以通过jax.core
中的单一原语扩展 API 替换。用于降低的后三者可以简化为一个模块,也许。
今天,为了编写转换规则,例如用于自动微分和批处理的规则,调用者可能需要与跟踪器相关的符号,例如JVPTracer
和BatchTracer
。以后可能可以避免这种情况,并允许我们从jax
中移除跟踪器类型。
这个模块加上jex.core
应该足以复制今天的自定义原语教程(例如我们的教程和dfm 的教程)。例如,定义一个原语及其在jax.jit
下的行为可能如下(在短期内):
from jax.extend import core # Previously: from jax import core from jax.extend.interpreters import mlir # ... and similarly mul_add_p = core.Primitive('mul_add') mul_add_p.def_impl(lambda x, y, z: x * y + z) @mul_add_p.def_abstract_eval def mul_add_abstract(x_sa, y_sa, z_sa): return core.ShapedArray(x_sa.shape, x_sa.dtype) def mul_add_mlir(ctx, xc, yc, zc): add = mlir.hlo.AddOp mul = mlir.hlo.MulOp return add(mul(xc, yc), zc).results mlir.register_lowering(mul_add_p, mul_add_mlir) import jax print(mul_add_p.bind(2, 3, 4)) # -> 10 print(jax.jit(mul_add_p.bind)(2, 3, 4)) # -> Array(10, dtype=int32)
jax.extend.random
这个模块可以暴露出我们定义新的随机数生成器实现的机制,并提供用于处理 PRNG 密钥内部的函数(参见问题#9263),例如当前的jax._src.prng.random_wrap
和random_unwrap
。
它还可以暴露出构成内置随机数生成器实现基础的键控哈希函数,例如jax._src.prng.threefry_2x32
。
jax.extend.sharding
这个模块可以暴露出用于分片分布式数组的低级实用工具。
目前我们只考虑了一项。XLA 编译器的数组分片格式比JAX 提供的那些更具表现力。我们可以将其作为jex.sharding.XlaOpShardingProto
提供,对应于今天内部的jax._src.lib.xla_client.OpSharding
。
复制引发收集的有效转置
mattjj@,dougalm@
2023 年 8 月
动机
我们在自动转置包含某些收集的shmap
中遇到了效率问题。问题出现在psum
和all_gather
,特别是当收集的输出作为未映射的输出返回给调用者时。这并不是一个边缘情况:例如,在应用grad
到基于shmap
的批量数据并行神经网络损失函数时,使用psum
来计算总损失。
我们已经知道这个问题有一段时间了。与pmap
类似的问题存在,尽管通过在pmap
内部而不是外部保留grad
来解决了这个问题。不完全的带有名称的avals-with-names
工作的一个主要目标是解决这个转置效率问题的一个版本。这篇文档借鉴了这些想法,同时对其进行了扩展和修订,以处理更多情况,并且更易于落地。事实上,这里提出的解决方案只影响shmap
的实现。其余系统不需要更改(暂时)。
这篇文档的主要目的是定义这个转置效率问题,并提出一个易于落地的解决方案。
这篇文档不涉及:
- 数组上的逻辑轴名称(这里的唯一轴名称与
shmap
和 OGpmap
中的轴名称一样); - 更改自动微分语义(所有数字和(非)错误保持不变,我们只是提高效率);
- 允许用户代码反映任何新信息,或者实际上根本不影响用户代码。
问题:psum
或all_gather
的有效转置取决于共享设备上的余切是否不变
考虑这个半真实的例子,旨在类似于一个复制参数批量数据并行损失函数:
devices = jax.devices() # 8 devices @partial(shmap, mesh=Mesh(devices, ('batch',)), in_specs=(P(None, None), P('batch', None)), out_specs=P()) def loss(params, batch): inputs, targets = batch predictions = predict(params, inputs) local_loss = jnp.mean(jnp.sum(predictions - targets, -1)) global_loss = lax.pmean(local_loss, 'batch')) return global_loss
注意out_specs=P()
,它指示未映射的输出。如果您不熟悉未映射输出的概念,请参阅本文档底部的附录。
在loss
示例中的大多数细节并不重要。对于我们的目的来说,唯一重要的是我们在最后应用了psum
(或者更确切地说是pmean = lambda x, name: psum(x, name) / psum(1, name)
)。因此,一个精简版本看起来像这样:
# Example 1: shmap involving psum and unmapped output with inefficient transpose f1 = shmap(lambda x: psum(g(x), 'i'), in_specs=P('i'), out_specs=P())
甚至通过抑制mesh
参数简化了符号。在接下来的例子中,可以从上下文中推断出来。
什么样的转置看起来像?写t
来表示函数转置,我们可以通过应用下面的函数¿f1_transpose?
有效地评估任意ybar
对应的t(f1)(ybar)
:
# An efficient "transpose" of Example 1 (but don't transpose this again!) ¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))
但这并不是我们当前获得的转置t(f1)
。
相反,当前的转置配方大致是我们交换in_specs
和out_specs
,对未映射输出进行一些分区重缩放,并转置主体。因为psum
本身是其自身的转置(作为全归约和的总和),我们最终会产生这个转置:
# The transpose we currently get for Example 1 (which is fine to transpose again) t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')), in_specs=P(), out_specs=P('i'))
这个转置虽然得到了正确的数字,但是很浪费。我们从转置的 in_specs=P()
静态地知道 ybar
对于每个函数实例都具有相同的值,即其值对于沿着被命名为 i
的网格轴的设备是不变的,然而我们还是对它应用了 psum
!这使用了昂贵的通信来将每个设备上的值乘以 8。(这里的 8 指的是轴 i
的大小。除以 8 来自于原始函数的 out_specs=P()
;它和微不足道的 psum
基本上互相抵消了。)
我们做错了什么?我们没有利用 cotangents
ybar
对应于 f1
的未映射输出是设备不变的这一事实;相反,我们像防御性地 psum
它们一样处理它们,就像 psum
的转置不能确定它们一样。有时 psum
是必要的,比如对于关于其第一个参数的 f2
的转置:
# Example 2: shmap involving psum and *mapped* output with efficient transpose f2 = shmap(lambda x, y: psum(g(x), 'i') * y, in_specs=(P('i'), P('i')), out_specs=P('i')) # The transpose we currently get for Example 2 is efficient t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')), in_specs=(P('i'), P('i')), out_specs=P('i'))
直观地说,如果我们的转置机制能区分示例 1 和示例 2,我们可以通过尽可能避免在可能的情况下避免 psum
和除法来做得更好。
低效的示例甚至可以更小。考虑转置这个被诅咒的恒等函数:
# Example 3: cursed identity cursed_identity = shmap(lambda x: x, P(), P()) # Currently we get these inefficient transposes t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P()) t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i')), P(), P()) ...
随着我们的转置越来越多,它变得越来越大。真丢人!
而 psum
并不是唯一的问题。类似的情况也适用于 all_gather
:
# Example 4: all_gather to an unmapped output f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P()) # Currently we get this inefficient transpose t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))
这个程序有点人为。为什么要做一个 all_gather
并将结果馈送到未映射的输出,而不是跳过主体中的 all_gather
并仅使用 out_specs=P('i')
收集结果?但即使是虚构的,这个例子仍然展示了一个不必要执行通信的转置(我们本可以执行一个非通信的切片),类似于示例 1 中的 psum
。
类似于 psum
示例,防御性的 psum_scatter
在某些情况下是必要的:
# Example 5: all_gather to a mapped output f5 = shmap(lambda x, y: all_gather(x, 'i') * y, in_specs=(P('i'), P('i')), out_specs=P('i')) # Currently we get this efficient transpose t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'), in_specs=(P('i'), P('i')), out_specs=P('i'))
那么我们如何避免这些低效的转置呢?
解决方案
这里有两个解决方案的想法。它们并不是互斥的。但是(剧透),第二个更好,并且它是我们所需的全部。
部分解决方案 “P-sum”:构建能够将 psum
表达到 out_specs
中的能力
这个解决方案有点像一个草人,因为它只会提供一个笨拙的编程方式。而且它甚至不能解决所有问题!但是,考虑到激励更完整的解决方案,这也值得一试。
上面的示例 4 是人为的,因为我们本可以在主体中使用 out_specs
而不是一个 all_gather
:
# Example 4 again f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P()) # Why didn't we just write it like this? f4_better = shmap(lambda x: x, P('i'), P('i'))
f4_better
版本没有任何转置问题,因为转置问题源于主体中的集体操作。
类似地,我们可以通过扩展 out_specs
来修复示例 1,以便它们可以表达求和:
# Example 1 again f1 = shmap(lambda x: psum(g(x), 'i'), in_specs=P('i'), out_specs=P()) # What if we could write an output sum like this? f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i')) # sum='i' means sum over that axis # Then it could transpose like this: t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i')) t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), P(sum='i'))
因此,提供内置到 out_specs
的 psum
解决了示例 1 中的转置问题。但它并没有完全解决示例 3 中的被诅咒的恒等转置:
# Example 3 again cursed_identity = shmap(lambda x: x, P(), P()) # How it would transpose with the P-sum partial solution: t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i')) t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))
尽管程序不会随着我们继续转置而继续增大,这是一个改进,但我们仍在进行浪费的通信。
完整解决方案:静态追踪设备变化与设备不变的中间值,以及新的基元
这个解决方案有两个组成部分:
- 追踪数值在特定网格轴上保证是设备不变还是设备变化的时机,
- 将
psum
分解为两步过程,引入一个新的pbroadcast
基元,并引入all_gather
及其转置的新基元。
从道义上讲,追踪设备不变与设备变化信息是一种类型级别的考虑。但为了第一次实现的方便起见,我们不需要在抽象值或者 jaxpr 类型中真正添加这些信息。在实施之前,我们会先使用类型引入这个想法。
同样将讨论如何使用户 API 既方便又向后兼容。但首先介绍这个想法时,我们会忽略方便性,而是尽可能地编写显式的代码。
在 avals(又称带名称的 avals,复活)中追踪设备不变性
有时候仅仅通过静态信息,我们就可以断定在shmap
的主体中一些中间变量的值在整个网格轴上是不变的,这意味着沿着网格轴的函数实例(及其对应的设备)必须都在使用相同的值进行计算。我们将这样的值称为设备不变的。对于那些不是设备不变的值,我们将它们称为设备变化的,尽管从类型系统的角度来看,我们其实是指它们可能在设备层面上是变化的。
要在类型中编码设备变化,我们将扩展数组类型的语法。我们会写类似x:f32[3,4]{i}
来表示x
在网格轴i
上(可能)是设备变化的(在shmap
的其他网格轴上是设备不变的)。更一般地说,我们会说数组类型语法的语法是这样的
shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type> device_variance_type ::= {<axis_name>, ...}
我们还将更新类型规则来处理设备变化类型
- 对于除了集合之外的一阶基元
- 对于多元基元,操作数设备变化类型必须相等,形状也必须相等,例如
mul x:f32[s1]{r1} y:f32[s2][r2]
要求除了s1 == s2
外还要求r1 == r2
- 输出设备变化类型必须与操作数相同
- 对于高阶基元
- 我们只需实例化包括设备变化类型在内的任何类型变量(并检查类型是否相等,检查它们的设备变化类型是否相等)
- (当进行类型推断时,例如对
cond
的分支,我们会取设备变化类型中轴名称集合的并集)
- 对于第一阶集合
- 一个集合可以接受设备变化或设备不变的输入(沿着对应其轴名称参数的网格轴);将设备不变的操作数传递给接受设备变化操作数的集合,反之亦然,会导致错误
- 一个集合可以产生设备变化或设备不变的输出
- 请看下面的表格 作为一个附带的好处,任何实现此类型检查的逻辑都可以包含
shmap
的“静态分析”检查,以确定任何未映射的out_specs
是否与其兼容。
这里是一个总结集体原语设备差异类型的表格:
名称 | 设备差异类型 | 示例 | 降低到 HLO | 转置 |
psum2 |
可变 -> 不变 |
y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i') |
AllReduceSum (通讯) |
pbroadcast |
pbroadcast |
不变 -> 可变 |
y:f32[3]{i} = pbroadcast(x:f32[3], 'i') |
no-op(无通讯) | psum |
all_to_all |
可变 -> 可变 |
y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0) AllToAll (通讯) |
all_to_all |
|
axis_index |
() -> 可变 |
idx:i32[]{i} = axis_index('i') |
ReplicaId 和一些算术运算(无通讯) |
n/a |
psum_scatter |
可变 -> 可变 |
y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i') |
ReduceScatterSum (通讯) |
all_gather |
all_gather |
可变 -> 可变 |
y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i') |
AllGather (通讯) |
psum_scatter |
pscatter |
不变 -> 可变 |
y:f32[2]{i} = pscatter(x:f32[16], 'i') |
lambda x: x[axis_index('i'), None] (无通讯) |
all_gather_invariant |
all_gather_invariant |
可变 -> 不变 |
y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i') |
AllGather (通讯) |
pscatter |
这里有一些令人惊讶的事情!
- 我们引入了几个新的原语,包括
pbroadcast
,有趣的是降低为 no-opall_gather_invariant
,它降低到与all_gather
相同的内容,但具有不同的设备差异类型(实质上all_gather
中融合了pbroadcast
,而all_gather_invariant
没有)pscatter
,它是all_gather_invariant
的对偶(转置)
all_gather
有一个设备可变的结果
直觉上,引入 pbroadcast
的原因(除了使类型规则生效之外)是为了使 psum
能转置为物理上的 no-op。我们需要 all_gather
有一个设备可变的结果,这样我们就可以将其转置为 psum_scatter
;如果我们将其留在设备不变的结果上,可能需要下游的 pbroadcast
,这种组合将转置为低效的 psum
,然后是切片 / pscatter
。因此,我们将 pbroadcast
“融合到” all_gather
中,从而实现有效的转置为 psum_scatter
。我们提供 all_gather_invariant
及其转置 pscatter
主要是为了完整性;用户不太可能需要它(它对应于示例 4 中的情况,可以使用 out_specs
进行不同写作)。
有趣的是,psum
和 pbroadcast
的转置对应于用户在训练 LLMs 时引入的 pmap
中的 psum_idrev
和 id_psumrev
。
这个系统是如何解决低效转置示例的
再次考虑简化的激励示例:
# Example 1 again f1 = shmap(lambda x: psum(g(x), 'i'), in_specs=P('i'), out_specs=P()) # Example 1 with intermediate device variance types annotated @partial(shmap, in_specs=P('i'), out_specs=P()) def f1(x: f32[3,4]{i}): w:f32[]{i} = g(x) y:f32[]{} = psum(w, 'i') return y
使用这些新规则,转置为:
# Example 1 transpose using device variance types (go ahead and transpose this again!) t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')), in_specs=P(), out_specs=P('i')) # Example 1 transpose with intermediate device variance types annotated @partial(shmap, in_specs=P('i'), out_specs=P()) def f1_transpose(ybar: f32[]): wbar:f32[]{i} = pbroadcast(ybar, 'i') xbar:f32[3,4]{i} = transpose(g)(wbar) return xbar
在评估 pbroadcast
应用程序时完全不涉及通信或 FLOP;这是一个无操作。请注意,如果我们保持转置,主体的大小不会增长;确实 t(t(f1)) == f1
。实现了效率!
只要我们在需要时插入 pbroadcast
以进行类型检查,我们就不会搞砸其他示例:
# Example 2 rewritten with explicit pbroadcast f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y, in_specs=(P('i'), P('i')), out_specs=P('i')) # Example 2 transpose using device variance types t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')), in_specs=(P('i'), P('i')), out_specs=P('i')) # Example 3 again cursed_identity = shmap(lambda x: x, P(), P()) # Notice here the body is `f32[...] -> f32[...]`, i.e. no device varying type. # Example 3 transpose using device variance types t(cursed_identity) = shmap(lambda x: x, P(), P()) t(t(cursed_identity)) = shmap(lambda x: x, P(), P())
直观地,在示例 1 中,我们现在只有“原始 psum 的一半”,而在示例 2 中,我们得到了“两半”。对于示例 3,我们根本不需要主体中的任何操作。
对于 all_gather
示例,示例 4 将需要使用 all_reduce_invariant
来实现有效的转置(虽然最好是在主体中使用 out_specs
而不是集体操作):
# Example 4 rewritten with explicit all_reduce_invariant f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P()) # Example 4 with intermediate device variance types annotated @partial(shmap, P('i'), P()) def f4(x:f32[1]{i}): y:f32[8]{} = all_gather_invariant(x, 'i') return y # Example 4 transpose with intermediate device variance types annotated @partial(shmap, in_specs=P(), out_specs=P('i')) def f4_transpose(ybar:f32[8]): xbar:f32[1]{i} = pscatter(ybar, 'i') return xbar
对于示例 5,使用设备变化的 all_gather
的效果与我们期望的一样:
# Example 5 with intermediate device variance types annotated @partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i')) def f5(x:f32[1]{i}, y:f32[8]{i}): z:f32[8]{i} = all_gather(x, 'i') w:f32[8]{i} = z * y return w # Transpose with respect to first argument @partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i')) def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}): zbar:f32[8]{i} = wbar * y xbar:f32[1]{i} = psum_scatter(zbar, 'i') return xbar
如何使 API 对用户方便(并保持向后兼容)
但是,有哪位用户愿意编写pbroadcast
?有哪位开发人员愿意破坏许多现有用户代码,其中包括未输入到未映射输出的 psum
?不包括我!
相反,我们可以自动插入pbroadcast
。这有点类似于我们在 jax.numpy
层执行自动等级提升时的方式,插入广播以避免二元运算符中的等级不匹配错误。但它要简单得多,因为我们不需要处理形状元组。典型的规则是:每当我们看到一个多元操作,其中操作数在设备方差类型上存在差异时,我们将操作数的设备方差类型的轴名称集合的并集,并插入pbroadcast
以将每个操作数提升到结果设备方差类型。
在需要之前自动插入 pbroadcast
可能意味着我们对相同的操作数多次应用相同的 pbroadcast
,从而创建共同子表达式。当我们转置时,这些可能会变成 psum
的和而不是 psum
的总和。我们将依赖编译器根据需要进行清理。如果这是个问题,我们可以向 pbroadcast
插入通行证添加一些简单的记忆化处理。
all_gather
的用户 API 将默认为 all_gather_p
(而不是 all_gather_invariant_p
),涵盖常见情况,意味着不需要插入 pbroadcast
。
我们可以在 shmap
上提供一个选项来禁用这种自动插入pbroadcast
,在这种情况下,用户需要确保类型正确。这种显式选项可能对一些人很有吸引力,他们希望明确指定向后传递中 psum
出现的位置。
如何实现解决方案
使实现轻量级的关键是我们不会将这些类型添加到 avals 或 jaxprs 中。至少起初不会。这可能很昂贵,因为它需要更新 JAX 的其余部分,例如 avals 和 jaxprs 的所有消费者可能需要处理新类型。我们不会再次上当!
相反,我们将保留这些扩展类型作为shmap
的内部元数据,就像当前的“out_specs
复制检查”机制一样。实际上,这个解决方案相当于对现有机制的相对小的扩展:它已经在跟踪相同的信息;现在我们只是添加了pbroadcast
。
我们至少有两种选择来执行pbroadcast
插入的位置:
- 就在转置之前,在转置规则中,我们有了计算的 jaxpr;
- 在每个
shmap
主体中,无论是急切执行还是分阶段输出,都要像当前的“out_specs
复制检查”机制一样。前者可能更容易,因为我们只需要处理 jaxpr 案例,并且只有线性原语。但我们将首先尝试后者,以便此处的实现是对现有复制检查逻辑的严格修订/扩展。
附录:定义和激励具有未映射输入和输出的映射
对于具体性,我们将主要关注shmap
,尽管这些想法同样适用于例如pmap
和可能的xmap
。
当对应的in_specs
条目未提及该网格轴的名称时,参数/输入沿着网格轴是未映射的。逻辑上意味着每个沿着该网格轴的函数实例对于参数得到相同的值。对于调用者来说,每个操作数根据其映射的网格轴进行切片,而对于未映射的网格轴,则没有切片。
当对应的out_specs
条目未提及该网格轴的名称时,输出沿着网格轴是未映射的。逻辑上意味着每个沿着该网格轴的函数实例必须返回相同的值。对于调用者来说,shmap
的每个结果由沿着输出映射的每个函数实例的返回值串联而成,而对于未映射的网格轴,则只使用该值的一个副本。
参见《shmap
JEP》,其中展示了未映射输入和输出的示例。作为比较,在vmap
中,未映射的输入/输出通过使用in_axes
/ out_axes
为None
(而不是int
)来指示。
这里是我们喜欢shmap
的未映射输入和输出的原因:
- 与
pjit
相同的表达能力。 任何pjit
能做的事情,shmap
逃逸通道也应该能做到。否则我们就会缺少逃逸通道!如果shmap
中没有未映射的输出,那么我们无法表达与pjit
相同的批并行损失函数计算。 - 闭合输入。 闭合的输入实际上对应于未映射的输入,以及…
- 转置闭包。 一旦我们有了未映射的输入,将其转置到未映射的输出就是很自然的事情。
因此,未映射的输出既是规范的又是有用的!
JAX 中文文档(十二)(5)https://developer.aliyun.com/article/1559719