JAX 中文文档(十二)(4)

简介: JAX 中文文档(十二)

JAX 中文文档(十二)(3)https://developer.aliyun.com/article/1559716

jax.extend:一个用于扩展的模块

原文:jax.readthedocs.io/en/latest/jep/15856-jex.html

@froystig, @sharadmv, @jakevdp, @yashk2810

2023 年 5 月

import jax.extend as jex 

多个项目依赖于 JAX 的代码库内部,通常用于使用其核心机制(例如编写其 IR 上的转换)或扩展它(例如定义新的原语)。这些依赖的两个挑战是(a)我们的内部结构并不都是为外部使用而设计的,以及(b)绕过 JAX 的公共 API 是不受支持的。换句话说,我们的内部经常被用作库,但既不像库那样结构化也不像库那样更新。

此提案考虑引入一个jax.extend模块,定义 JAX 一些内部组件的库视图。我们将其视为第二层 API,仍然基本不保证兼容性政策,但希望在发生更改时更容易发现。

jax.extend的受众包括与 JAX 相关的 Python 库,如Oryxjax-triton等,以及进行函数转换、自动微分系统、数值编程编译器前端等实验的项目。

本说明概述了jax.extend现在和将来可能的样子。它没有详细列出所有细节,而是建议我们开始逐步开发这个模块。

注意,jax.extendjax.experimental不同,后者是新功能和正在进行的想法的一个暂存场所。通常,jax.experimental中的工作最终会进入另一个 JAX 模块或被完全移除。

没有兼容性政策

为了保持开发的开销低,jax.extend不会遵循公共API 兼容性政策。它将承诺没有弃用窗口,也没有版本间的向后兼容性。每个发布都可能会破坏现有的调用者,没有简单的回退措施(例如没有重新引入先前行为的标志)。我们将依赖变更日志来指出这些更改。

调用jax.extend的调用者可能会发现在 JAX 发布时与其代码一起定期升级对他们有用。这是当今依赖 JAX 内部的项目的一个常见习惯。不同之处在于现在它将以更好的意图和更好的库设计和命名帮助中,伴随着变更日志公告的形式出现。

逐步开发

没有兼容性政策使得在实施上更容易上手:第一天,我们可以从内部包(如jax._src)中移植少量符号到今天的jax.corejax.interpreters。然后我们可以迭代改进。

可能的模块概述

我们可以设想,最终jax.extend可能包括以下模块:

  • core – 原语,Jaxpr IR 等。
  • interpreters – 核心转换(例如自动微分、批处理)和降低。
  • random – 随机位生成、关键分割和折叠、关键数组。
  • sharding – 关于分布式数组的额外功能。

最初在模块中可能还有其他符号,例如jex.api_util,随着我们的工作,我们将移除或替换它们。其他的时间会决定。例如,jex.lib可能在短期内提供访问 jexlib 的入口点,但是目前还不清楚我们是否想长期保留它。

对每个这些内容可能包含的一些初步想法的追踪。

jax.extend.core

这应该至少使调用者能够定义新的 JAX 原语并处理 Jaxpr IR(jax.make_jaxpr(...)的输出)。支持这一点可能涉及提供:

  • 访问现有的核心系统原语,例如今天的jax._src.lax.add_p
  • 访问 IR 类型,例如当前的jax._src.core.ShapedArray
  • 用于检查和漂亮打印 jaxprs 的功能。
  • 明确构建 jaxprs 的功能,而不是通过jax.make_jaxpr分阶段地阶段 Python 函数(或不阶段化!)。

在初始化时,这个模块将包含比定义原语和规则所需更多的符号,包括在设置“最终风格转换”时使用的各种名称,例如当前的jax._src.core.TraceTracer类。我们可以重新审视jex.core是否应该支持初始风格方法以及是否可以通过比完全暴露TraceTracer更狭窄的 API 来支持最终风格扩展。Oryx可能会帮助指导这些决策。

我们还可以考虑将make_jaxpr本身迁移到jax.core中。

jax.extend.interpreters

此模块将提供注册各种原语转换规则的手段 —— 定义它们在自动微分、批处理、降低等方面的行为。

最初将反映jax._src.interpreters,提供模块adbatchingpartial_eval(用于将 Python 编程转换为 Jaxpr,并用于自动微分中的线性化)、mlirpxlaxla。前三者可能可以通过jax.core中的单一原语扩展 API 替换。用于降低的后三者可以简化为一个模块,也许。

今天,为了编写转换规则,例如用于自动微分和批处理的规则,调用者可能需要与跟踪器相关的符号,例如JVPTracerBatchTracer。以后可能可以避免这种情况,并允许我们从jax中移除跟踪器类型。

这个模块加上jex.core应该足以复制今天的自定义原语教程(例如我们的教程dfm 的教程)。例如,定义一个原语及其在jax.jit下的行为可能如下(在短期内):

from jax.extend import core          # Previously: from jax import core
from jax.extend.interpreters import mlir        # ... and similarly
mul_add_p = core.Primitive('mul_add')
mul_add_p.def_impl(lambda x, y, z: x * y + z)
@mul_add_p.def_abstract_eval
def mul_add_abstract(x_sa, y_sa, z_sa):
  return core.ShapedArray(x_sa.shape, x_sa.dtype)
def mul_add_mlir(ctx, xc, yc, zc):
  add = mlir.hlo.AddOp
  mul = mlir.hlo.MulOp
  return add(mul(xc, yc), zc).results
mlir.register_lowering(mul_add_p, mul_add_mlir)
import jax
print(mul_add_p.bind(2, 3, 4))            # -> 10
print(jax.jit(mul_add_p.bind)(2, 3, 4))   # -> Array(10, dtype=int32)

jax.extend.random

这个模块可以暴露出我们定义新的随机数生成器实现的机制,并提供用于处理 PRNG 密钥内部的函数(参见问题#9263),例如当前的jax._src.prng.random_wraprandom_unwrap

它还可以暴露出构成内置随机数生成器实现基础的键控哈希函数,例如jax._src.prng.threefry_2x32

jax.extend.sharding

这个模块可以暴露出用于分片分布式数组的低级实用工具。

目前我们只考虑了一项。XLA 编译器的数组分片格式比JAX 提供的那些更具表现力。我们可以将其作为jex.sharding.XlaOpShardingProto提供,对应于今天内部的jax._src.lib.xla_client.OpSharding

复制引发收集的有效转置

jax.readthedocs.io/en/latest/jep/17111-shmap-transpose.html

mattjj@dougalm@

2023 年 8 月

动机

我们在自动转置包含某些收集的shmap中遇到了效率问题。问题出现在psumall_gather,特别是当收集的输出作为未映射的输出返回给调用者时。这并不是一个边缘情况:例如,在应用grad到基于shmap的批量数据并行神经网络损失函数时,使用psum来计算总损失。

我们已经知道这个问题有一段时间了。与pmap类似的问题存在,尽管通过在pmap内部而不是外部保留grad来解决了这个问题。不完全的带有名称的avals-with-names工作的一个主要目标是解决这个转置效率问题的一个版本。这篇文档借鉴了这些想法,同时对其进行了扩展和修订,以处理更多情况,并且更易于落地。事实上,这里提出的解决方案只影响shmap的实现。其余系统不需要更改(暂时)。

这篇文档的主要目的是定义这个转置效率问题,并提出一个易于落地的解决方案。

这篇文档不涉及:

  • 数组上的逻辑轴名称(这里的唯一轴名称与shmap和 OG pmap中的轴名称一样);
  • 更改自动微分语义(所有数字和(非)错误保持不变,我们只是提高效率);
  • 允许用户代码反映任何新信息,或者实际上根本不影响用户代码。

问题:psumall_gather的有效转置取决于共享设备上的余切是否不变

考虑这个半真实的例子,旨在类似于一个复制参数批量数据并行损失函数:

devices = jax.devices()  # 8 devices
@partial(shmap, mesh=Mesh(devices, ('batch',)),
         in_specs=(P(None, None), P('batch', None)),
         out_specs=P())
def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
  global_loss = lax.pmean(local_loss, 'batch'))
  return global_loss 

注意out_specs=P(),它指示未映射的输出。如果您不熟悉未映射输出的概念,请参阅本文档底部的附录。

loss示例中的大多数细节并不重要。对于我们的目的来说,唯一重要的是我们在最后应用了psum(或者更确切地说是pmean = lambda x, name: psum(x, name) / psum(1, name))。因此,一个精简版本看起来像这样:

# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P()) 

甚至通过抑制mesh参数简化了符号。在接下来的例子中,可以从上下文中推断出来。

什么样的转置看起来像?写t来表示函数转置,我们可以通过应用下面的函数¿f1_transpose?有效地评估任意ybar对应的t(f1)(ybar)

# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i')) 

但这并不是我们当前获得的转置t(f1)

相反,当前的转置配方大致是我们交换in_specsout_specs,对未映射输出进行一些分区重缩放,并转置主体。因为psum本身是其自身的转置(作为全归约和的总和),我们最终会产生这个转置:

# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
              in_specs=P(), out_specs=P('i')) 

这个转置虽然得到了正确的数字,但是很浪费。我们从转置的 in_specs=P() 静态地知道 ybar 对于每个函数实例都具有相同的值,即其值对于沿着被命名为 i 的网格轴的设备是不变的,然而我们还是对它应用了 psum!这使用了昂贵的通信来将每个设备上的值乘以 8。(这里的 8 指的是轴 i 的大小。除以 8 来自于原始函数的 out_specs=P();它和微不足道的 psum 基本上互相抵消了。)

我们做错了什么?我们没有利用 cotangents ybar 对应于 f1 的未映射输出是设备不变的这一事实;相反,我们像防御性地 psum 它们一样处理它们,就像 psum 的转置不能确定它们一样。有时 psum 是必要的,比如对于关于其第一个参数的 f2 的转置:

# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
          in_specs=(P('i'), P('i')), out_specs=P('i'))
# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
                in_specs=(P('i'), P('i')), out_specs=P('i')) 

直观地说,如果我们的转置机制能区分示例 1 和示例 2,我们可以通过尽可能避免在可能的情况下避免 psum 和除法来做得更好。

低效的示例甚至可以更小。考虑转置这个被诅咒的恒等函数:

# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())
# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i')), P(), P())
... 

随着我们的转置越来越多,它变得越来越大。真丢人!

psum 并不是唯一的问题。类似的情况也适用于 all_gather

# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i')) 

这个程序有点人为。为什么要做一个 all_gather 并将结果馈送到未映射的输出,而不是跳过主体中的 all_gather 并仅使用 out_specs=P('i') 收集结果?但即使是虚构的,这个例子仍然展示了一个不必要执行通信的转置(我们本可以执行一个非通信的切片),类似于示例 1 中的 psum

类似于 psum 示例,防御性的 psum_scatter 在某些情况下是必要的:

# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))
# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
                 in_specs=(P('i'), P('i')), out_specs=P('i')) 

那么我们如何避免这些低效的转置呢?

解决方案

这里有两个解决方案的想法。它们并不是互斥的。但是(剧透),第二个更好,并且它是我们所需的全部。

部分解决方案 “P-sum”:构建能够将 psum 表达到 out_specs 中的能力

这个解决方案有点像一个草人,因为它只会提供一个笨拙的编程方式。而且它甚至不能解决所有问题!但是,考虑到激励更完整的解决方案,这也值得一试。

上面的示例 4 是人为的,因为我们本可以在主体中使用 out_specs 而不是一个 all_gather

# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i')) 

f4_better 版本没有任何转置问题,因为转置问题源于主体中的集体操作。

类似地,我们可以通过扩展 out_specs 来修复示例 1,以便它们可以表达求和:

# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())
# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i'))  # sum='i' means sum over that axis
# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), P(sum='i')) 

因此,提供内置到 out_specspsum 解决了示例 1 中的转置问题。但它并没有完全解决示例 3 中的被诅咒的恒等转置:

# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i')) 

尽管程序不会随着我们继续转置而继续增大,这是一个改进,但我们仍在进行浪费的通信。

完整解决方案:静态追踪设备变化与设备不变的中间值,以及新的基元

这个解决方案有两个组成部分:

  1. 追踪数值在特定网格轴上保证是设备不变还是设备变化的时机,
  2. psum分解为两步过程,引入一个新的pbroadcast基元,并引入all_gather及其转置的新基元。

从道义上讲,追踪设备不变与设备变化信息是一种类型级别的考虑。但为了第一次实现的方便起见,我们不需要在抽象值或者 jaxpr 类型中真正添加这些信息。在实施之前,我们会先使用类型引入这个想法。

同样将讨论如何使用户 API 既方便又向后兼容。但首先介绍这个想法时,我们会忽略方便性,而是尽可能地编写显式的代码。

在 avals(又称带名称的 avals,复活)中追踪设备不变性

有时候仅仅通过静态信息,我们就可以断定在shmap的主体中一些中间变量的值在整个网格轴上是不变的,这意味着沿着网格轴的函数实例(及其对应的设备)必须都在使用相同的值进行计算。我们将这样的值称为设备不变的。对于那些不是设备不变的值,我们将它们称为设备变化的,尽管从类型系统的角度来看,我们其实是指它们可能在设备层面上是变化的。

要在类型中编码设备变化,我们将扩展数组类型的语法。我们会写类似x:f32[3,4]{i}来表示x在网格轴i上(可能)是设备变化的(在shmap的其他网格轴上是设备不变的)。更一般地说,我们会说数组类型语法的语法是这样的

shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...} 

我们还将更新类型规则来处理设备变化类型

  • 对于除了集合之外的一阶基元
  • 对于多元基元,操作数设备变化类型必须相等,形状也必须相等,例如mul x:f32[s1]{r1} y:f32[s2][r2]要求除了s1 == s2外还要求r1 == r2
  • 输出设备变化类型必须与操作数相同
  • 对于高阶基元
  • 我们只需实例化包括设备变化类型在内的任何类型变量(并检查类型是否相等,检查它们的设备变化类型是否相等)
  • (当进行类型推断时,例如对cond的分支,我们会取设备变化类型中轴名称集合的并集)
  • 对于第一阶集合
  • 一个集合可以接受设备变化或设备不变的输入(沿着对应其轴名称参数的网格轴);将设备不变的操作数传递给接受设备变化操作数的集合,反之亦然,会导致错误
  • 一个集合可以产生设备变化或设备不变的输出
  • 请看下面的表格 作为一个附带的好处,任何实现此类型检查的逻辑都可以包含 shmap 的“静态分析”检查,以确定任何未映射的 out_specs 是否与其兼容。

这里是一个总结集体原语设备差异类型的表格:

名称 设备差异类型 示例 降低到 HLO 转置
psum2 可变 -> 不变 y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i') AllReduceSum (通讯) pbroadcast
pbroadcast 不变 -> 可变 y:f32[3]{i} = pbroadcast(x:f32[3], 'i') no-op(无通讯) psum
all_to_all 可变 -> 可变 y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0) AllToAll (通讯) all_to_all
axis_index () -> 可变 idx:i32[]{i} = axis_index('i') ReplicaId 和一些算术运算(无通讯) n/a
psum_scatter 可变 -> 可变 y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i') ReduceScatterSum (通讯) all_gather
all_gather 可变 -> 可变 y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i') AllGather (通讯) psum_scatter
pscatter 不变 -> 可变 y:f32[2]{i} = pscatter(x:f32[16], 'i') lambda x: x[axis_index('i'), None] (无通讯) all_gather_invariant
all_gather_invariant 可变 -> 不变 y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i') AllGather (通讯) pscatter

这里有一些令人惊讶的事情!

  • 我们引入了几个新的原语,包括
  • pbroadcast,有趣的是降低为 no-op
  • all_gather_invariant,它降低到与 all_gather 相同的内容,但具有不同的设备差异类型(实质上 all_gather 中融合了 pbroadcast,而 all_gather_invariant 没有)
  • pscatter,它是 all_gather_invariant 的对偶(转置)
  • all_gather 有一个设备可变的结果

直觉上,引入 pbroadcast 的原因(除了使类型规则生效之外)是为了使 psum 能转置为物理上的 no-op。我们需要 all_gather 有一个设备可变的结果,这样我们就可以将其转置为 psum_scatter;如果我们将其留在设备不变的结果上,可能需要下游的 pbroadcast,这种组合将转置为低效的 psum,然后是切片 / pscatter。因此,我们将 pbroadcast “融合到” all_gather 中,从而实现有效的转置为 psum_scatter。我们提供 all_gather_invariant 及其转置 pscatter 主要是为了完整性;用户不太可能需要它(它对应于示例 4 中的情况,可以使用 out_specs 进行不同写作)。

有趣的是,psumpbroadcast 的转置对应于用户在训练 LLMs 时引入的 pmap 中的 psum_idrevid_psumrev

这个系统是如何解决低效转置示例的

再次考虑简化的激励示例:

# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())
# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
  w:f32[]{i} = g(x)
  y:f32[]{} = psum(w, 'i')
  return y 

使用这些新规则,转置为:

# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
              in_specs=P(), out_specs=P('i'))
# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1_transpose(ybar: f32[]):
  wbar:f32[]{i} = pbroadcast(ybar, 'i')
  xbar:f32[3,4]{i} = transpose(g)(wbar)
  return xbar 

在评估 pbroadcast 应用程序时完全不涉及通信或 FLOP;这是一个无操作。请注意,如果我们保持转置,主体的大小不会增长;确实 t(t(f1)) == f1。实现了效率!

只要我们在需要时插入 pbroadcast 以进行类型检查,我们就不会搞砸其他示例:

# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))
# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device varying type.
# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P()) 

直观地,在示例 1 中,我们现在只有“原始 psum 的一半”,而在示例 2 中,我们得到了“两半”。对于示例 3,我们根本不需要主体中的任何操作。

对于 all_gather 示例,示例 4 将需要使用 all_reduce_invariant 来实现有效的转置(虽然最好是在主体中使用 out_specs 而不是集体操作):

# Example 4 rewritten with explicit all_reduce_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())
# Example 4 with intermediate device variance types annotated
@partial(shmap, P('i'), P())
def f4(x:f32[1]{i}):
  y:f32[8]{} = all_gather_invariant(x, 'i')
  return y
# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
  xbar:f32[1]{i} = pscatter(ybar, 'i')
  return xbar 

对于示例 5,使用设备变化的 all_gather 的效果与我们期望的一样:

# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
  z:f32[8]{i} = all_gather(x, 'i')
  w:f32[8]{i} = z * y
  return w
# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
  zbar:f32[8]{i} = wbar * y
  xbar:f32[1]{i} = psum_scatter(zbar, 'i')
  return xbar 

如何使 API 对用户方便(并保持向后兼容)

但是,有哪位用户愿意编写pbroadcast?有哪位开发人员愿意破坏许多现有用户代码,其中包括未输入到未映射输出的 psum?不包括我!

相反,我们可以自动插入pbroadcast。这有点类似于我们在 jax.numpy 层执行自动等级提升时的方式,插入广播以避免二元运算符中的等级不匹配错误。但它要简单得多,因为我们不需要处理形状元组。典型的规则是:每当我们看到一个多元操作,其中操作数在设备方差类型上存在差异时,我们将操作数的设备方差类型的轴名称集合的并集,并插入pbroadcast以将每个操作数提升到结果设备方差类型。

在需要之前自动插入 pbroadcast 可能意味着我们对相同的操作数多次应用相同的 pbroadcast,从而创建共同子表达式。当我们转置时,这些可能会变成 psum 的和而不是 psum 的总和。我们将依赖编译器根据需要进行清理。如果这是个问题,我们可以向 pbroadcast 插入通行证添加一些简单的记忆化处理。

all_gather 的用户 API 将默认为 all_gather_p(而不是 all_gather_invariant_p),涵盖常见情况,意味着不需要插入 pbroadcast

我们可以在 shmap 上提供一个选项来禁用这种自动插入pbroadcast,在这种情况下,用户需要确保类型正确。这种显式选项可能对一些人很有吸引力,他们希望明确指定向后传递中 psum 出现的位置。

如何实现解决方案

使实现轻量级的关键是我们不会将这些类型添加到 avals 或 jaxprs 中。至少起初不会。这可能很昂贵,因为它需要更新 JAX 的其余部分,例如 avals 和 jaxprs 的所有消费者可能需要处理新类型。我们不会再次上当!

相反,我们将保留这些扩展类型作为shmap的内部元数据,就像当前的“out_specs复制检查”机制一样。实际上,这个解决方案相当于对现有机制的相对小的扩展:它已经在跟踪相同的信息;现在我们只是添加了pbroadcast

我们至少有两种选择来执行pbroadcast插入的位置:

  1. 就在转置之前,在转置规则中,我们有了计算的 jaxpr;
  2. 在每个shmap主体中,无论是急切执行还是分阶段输出,都要像当前的“out_specs复制检查”机制一样。前者可能更容易,因为我们只需要处理 jaxpr 案例,并且只有线性原语。但我们将首先尝试后者,以便此处的实现是对现有复制检查逻辑的严格修订/扩展。

附录:定义和激励具有未映射输入和输出的映射

对于具体性,我们将主要关注shmap,尽管这些想法同样适用于例如pmap和可能的xmap

当对应的in_specs条目未提及该网格轴的名称时,参数/输入沿着网格轴是未映射的。逻辑上意味着每个沿着该网格轴的函数实例对于参数得到相同的值。对于调用者来说,每个操作数根据其映射的网格轴进行切片,而对于未映射的网格轴,则没有切片。

当对应的out_specs条目未提及该网格轴的名称时,输出沿着网格轴是未映射的。逻辑上意味着每个沿着该网格轴的函数实例必须返回相同的值。对于调用者来说,shmap的每个结果由沿着输出映射的每个函数实例的返回值串联而成,而对于未映射的网格轴,则只使用该值的一个副本。

参见shmap JEP》,其中展示了未映射输入和输出的示例。作为比较,在vmap中,未映射的输入/输出通过使用in_axes / out_axesNone(而不是int)来指示。

这里是我们喜欢shmap的未映射输入和输出的原因:

  • pjit相同的表达能力。 任何pjit能做的事情,shmap逃逸通道也应该能做到。否则我们就会缺少逃逸通道!如果shmap中没有未映射的输出,那么我们无法表达与pjit相同的批并行损失函数计算。
  • 闭合输入。 闭合的输入实际上对应于未映射的输入,以及…
  • 转置闭包。 一旦我们有了未映射的输入,将其转置到未映射的输出就是很自然的事情。

因此,未映射的输出既是规范的又是有用的!


JAX 中文文档(十二)(5)https://developer.aliyun.com/article/1559719

相关文章
|
4月前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(4)
JAX 中文文档(十五)
40 3
|
4月前
|
存储 API 索引
JAX 中文文档(十五)(5)
JAX 中文文档(十五)
53 3
|
4月前
|
机器学习/深度学习 数据可视化 编译器
JAX 中文文档(十四)(5)
JAX 中文文档(十四)
51 2
|
4月前
|
算法 API 开发工具
JAX 中文文档(十二)(5)
JAX 中文文档(十二)
55 1
|
4月前
|
API 异构计算 Python
JAX 中文文档(十一)(4)
JAX 中文文档(十一)
33 1
|
4月前
JAX 中文文档(十一)(5)
JAX 中文文档(十一)
17 1
|
4月前
|
并行计算 异构计算 索引
JAX 中文文档(十六)(4)
JAX 中文文档(十六)
44 2
|
4月前
|
IDE API 开发工具
JAX 中文文档(十二)(2)
JAX 中文文档(十二)
36 0
|
4月前
|
编译器 API 调度
JAX 中文文档(十二)(1)
JAX 中文文档(十二)
61 0
|
4月前
|
算法 编译器 API
JAX 中文文档(十二)(3)
JAX 中文文档(十二)
34 0