使用 shard_map
的 SPMD 多设备并行性
shard_map
是一种单程序多数据(SPMD)多设备并行性 API,用于在数据分片上映射函数。映射的函数应用或实例通过显式的集合通信操作进行通信。
shard_map
是与 jit
内置的自动编译器并行化互补且可组合的。使用 jit
,你编写的代码就像为单个设备编写的一样,并且编译器可以自动将计算分区到多个设备上,在幕后生成每个设备的代码和通信集合。使用 shard_map
,你可以控制自己的分区代码和显式集合。或者你可以同时进行一些操作:在组设备中手动控制同时保留组内设备分区给编译器。这两种方法可以根据需要混合、匹配和组合。
如果您熟悉 pmap
,可以将 shard_map
视为其演进。它更具表现力、性能和与其他 JAX API 可组合。它甚至可以急切地工作,更易于调试!(更多信息,请参阅与 pmap
的详细比较。)
通过阅读本教程,您将学习如何使用 shard_map
来完全控制您的多设备代码。您将详细了解它如何与 jax.jit
的自动并行化和 jax.grad
的自动微分结合使用。我们还将给出一些神经网络并行化策略的基本示例。
我们假设本教程在具有八个设备的环境中运行:
import os os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices
所以,让我们来看一个 shard_map
吧!
不多说了,这里是一个玩具例子:
from functools import partial import jax import jax.numpy as jnp from jax.sharding import Mesh, PartitionSpec as P from jax.experimental import mesh_utils from jax.experimental.shard_map import shard_map
devices = mesh_utils.create_device_mesh((4, 2)) mesh = Mesh(devices, axis_names=('x', 'y')) a = jnp.arange( 8 * 16.).reshape(8, 16) b = jnp.arange(16 * 4.).reshape(16, 4) @partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)), out_specs=P('x', None)) def matmul_basic(a_block, b_block): # a_block: f32[2, 8] # b_block: f32[8, 4] c_partialsum = jnp.dot(a_block, b_block) c_block = jax.lax.psum(c_partialsum, 'y') # c_block: f32[2, 4] return c_block c = matmul_basic(a, b) # c: f32[8, 4]
这个函数通过执行本地块矩阵乘法,然后进行集合求和操作来并行计算矩阵乘积。我们可以检查结果是否正确:
from jax.tree_util import tree_map, tree_all def allclose(a, b): return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b)) allclose(c, jnp.dot(a, b))
True
结果沿其行被分片:
jax.debug.visualize_array_sharding(c)
CPU 0,1 CPU 2,3 CPU 4,5 CPU 6,7
在高层次上,shard_map
在某种程度上类似于 vmap
或 pmap
,因为我们在数组数据的部分上映射函数,但请注意
shard_map
将输入切片成块(输出由连接结果块形成),保持秩不变,而vmap
则通过映射掉一个轴来减少秩;mesh
参数允许我们控制计算和结果的精确设备放置;- 我们同时映射多个数据轴,并设置多个轴名称以进行集合操作(这里有
'x'
和'y'
); - 因为我们还没有使用
jax.jit
,一切都是急切地评估的,我们甚至可以打印中间值以进行调试。
上述代码执行与此 jax.jit
自动并行化代码相同的计算:
from jax.sharding import NamedSharding a = jax.device_put(a, NamedSharding(mesh, P('x', 'y'))) b = jax.device_put(b, NamedSharding(mesh, P('y', None))) @jax.jit def matmul_reference(a, b): c = jnp.dot(a, b) return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None))) c_ref = matmul_reference(a, b) allclose(c_ref, jnp.dot(a, b))
True
我们可以将 shard_map
视为根据其 mesh
和 in_specs
参数在其输入上执行 device_put
或 with_sharding_constraint
,因此 matmul_basic
操作的块与 matmul_reference
中的相同:
print('a blocks:'); jax.debug.visualize_array_sharding(a) print('b blocks:'); jax.debug.visualize_array_sharding(b) print('c blocks:'); jax.debug.visualize_array_sharding(c)
a blocks: b blocks: c blocks:
CPU 0CPU 1 CPU 2CPU 3 CPU 4CPU 5 CPU 6CPU 7
CPU 0,2,4,6 CPU 1,3,5,7
CPU 0,1 CPU 2,3 CPU 4,5 CPU 6,7
放慢速度,从基础开始!
降维与保持秩的映射
我们可以将 vmap
和 pmap
看作是沿轴(例如将 2D 矩阵解包成其 1D 行)对每个数组输入应用其主体函数,然后将结果堆叠在一起,至少在不涉及集合操作时是这样的:
def check_vmap(f, xs): ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs) expected = jnp.stack([f(x) for x in xs]) # vmap reference semantics print(allclose(ans, expected)) check_vmap(lambda x: x @ x, jnp.arange(12).reshape(4, 3))
True
例如,如果 xs
的形状为 f32[8,5]
,那么每个 x
的形状将为 f32[5]
,如果每个 f(x)
的形状为 f32[3,7]
,那么最终堆叠的结果 vmap(f)(xs)
的形状将为 f32[8,3,7]
。也就是说,函数 f
的每个应用都以比 vmap(f)
对应参数少一个轴的输入作为参数。我们可以说这些是降维映射,输入/输出的解包/堆叠。
函数 f
的逻辑应用数量,或称为 f
的实例,取决于被映射输入轴的大小:例如,如果我们映射一个大小为 8 的输入轴,语义上我们得到函数的 8 个逻辑应用。
相比之下,shard_map
并没有这种降维行为。相反,我们可以将其视为沿输入轴切片(或“取消连接”)成块,应用主体函数,然后将结果再次连接在一起(同样是在不涉及集合操作时):
import numpy as np devices = np.array(jax.devices()[:4]) mesh = Mesh(devices, ('i',)) # mesh.shape['i'] = 4 def check_shmap(f, y): ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y) expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])]) print(allclose(ans, expected)) check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4))
True
回想一下,jnp.split
将其输入切片为相同大小的块,以便如果在上述示例中 y
的形状为 f32[8,5]
,那么每个 y_blk
的形状将为 f32[2,5]
,如果每个 f(y_blk)
的形状为 f32[3,7]
,那么最终连接的结果 shard_map(f, ...)(y)
的形状将为 f32[12,7]
。因此,shard_map
对其输入进行保持秩的映射,输入/输出的取消连接/连接。
函数 f
的逻辑应用数量由网格大小决定,而不是任何输入轴的大小:例如,如果我们有总大小为 4 的网格(即在 4 个设备上),那么语义上我们得到函数的 4 个逻辑应用,对应于物理计算这些函数的 4 个设备。
控制每个输入如何分割(取消连接)并与 in_specs
平铺
每个 in_specs
通过 PartitionSpec
标识某些对应输入数组轴的网格轴名称,表示如何将该输入分割(或解串联)为应用体函数的块。该标识确定了碎片大小;当输入轴与网格轴标识为同一时,输入沿该逻辑轴分割(解串联)为数目等于相应网格轴大小的片段。(如果相应的网格轴大小不能整除输入数组轴大小,则出错。)如果输入的 pspec
未提及网格轴名称,则在该网格轴上没有分割。例如:
devices = mesh_utils.create_device_mesh((4, 2)) mesh = Mesh(devices, ('i', 'j')) @partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j')) def f1(x_block): print(x_block.shape) # prints (3, 12) return x_block x1 = jnp.arange(12 * 12).reshape(12, 12) y = f1(x1)
(3, 12)
在这里,因为输入 pspec
未提及网格轴名称 'j'
,因此没有输入数组轴沿该网格轴进行分割;类似地,因为输入数组的第二轴没有标识(因此没有沿任何网格轴分割),f1
的应用获得了沿该轴的完整视图。
当输入 pspec
中未提及网格轴时,我们可以始终重写为一个效率较低的程序,其中所有网格轴都被提及,但调用者执行 jnp.tile
,例如:
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j')) def f2(x_block): print(x_block.shape) return x_block x = jnp.arange(12 * 12).reshape(12, 12) x_ = jnp.tile(x, (1, mesh.shape['j'])) # x_ has shape (12, 24) y = f2(x_) # prints (3,12), and f1(x) == f2(x_)
(3, 12)
换句话说,因为每个输入 pspec
可以零次或一次提及每个网格轴名称,而不必精确一次提及每个名称,我们可以说除了其输入中内置的 jnp.split
外,shard_map
还有一个至少逻辑上内置的 jnp.tile
(尽管根据参数的物理分片布局,可能不需要进行物理铺设)。要使用的铺设方式不唯一;我们也可以沿第一个轴进行铺设,并使用 pspec P(('j', 'i'), None)
。
可以在输入上进行物理数据移动,因为每个设备都需要有适当数据的副本。
通过 out_specs
控制每个由连接、块转置和使用 out_specs
反铺设组装的输出。
类似于输入端,out_specs
中的每个标识符通过名称将输出数组的一些轴与网格轴相关联,表示应如何将输出块(每个体函数应用的一个,或等效地每个物理设备一个)重新组装以形成最终输出值。例如,在上述 f1
和 f2
的例子中,out_specs
表明我们应该沿两个轴连接在一起形成最终输出,结果在两种情况下都是形状为 (12, 24)
的数组 y
。(如果体函数的输出形状,即输出块形状,对应的输出 pspec
描述的连接的秩过小,则出错。)
当一个网格轴名称在输出 pspec 中未被提及时,表示一个取消铺设:用户编写一个输出 pspec,其中未提及网格轴名称之一,他们保证输出块沿该网格轴是相等的,因此在输出中只使用一个沿该轴的块(而不是沿该网格轴连接所有块)。例如,使用与上述相同的网格:
x = jnp.array([[3.]]) z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))() print(z) # prints the same as jnp.tile(x, (4, 2)) z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))() print(z) # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,)) z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))() print(z) # prints the same as jnp.tile(x, (1, 1)), or just x
[[3\. 3.] [3\. 3.] [3\. 3.] [3\. 3.]] [[3.] [3.] [3.] [3.]] [[3.]]
闭合在数组值上的主体函数等效于将其作为具有相应输入 pspec 的增强传递。作为另一个示例,更接近于上述其他示例:
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None)) def f3(x_block): return jax.lax.psum(x_block, 'j') x = jnp.arange(12 * 12).reshape(12, 12) y3 = f3(x) print(y3.shape)
(12, 6)
结果的第二轴大小为 6,输入的第二轴大小的一半。在这种情况下,通过在输出 pspec 中未提及网格轴名称 'j'
来表达取消铺设是安全的,因为集体 psum
确保每个输出块沿相应的网格轴是相等的。以下是两个更改输出 pspec 中提及的网格轴的示例:
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j')) def f4(x_block): return jax.lax.psum(x_block, 'i') x = jnp.arange(12 * 12).reshape(12, 12) y4 = f4(x) print(y4.shape) # (3,12) @partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None)) def f5(x_block): return jax.lax.psum(x_block, ('i', 'j')) y5 = f5(x) print(y5.shape) # (3,6)
(3, 12) (3, 6)
在物理方面,在输出 pspec 中未提及网格轴名称将使用沿该网格轴复制布局从输出设备缓冲区组装 Array
。
没有运行时检查,以确保输出块实际上沿网格轴是相等的,从而可以取消铺设,或者等效地说,相应的物理缓冲区具有相等的值,因此可以被解释为单个逻辑数组的复制布局。但是,我们可以提供一个静态检查机制,在所有潜在不正确的程序上引发错误。
因为 out_specs
可以零次或一次提及网格轴名称,并且可以以任何顺序提及,所以除了其输出中内置的 jnp.concatenate
外,shard_map
还包括 取消铺设 和 块转置。
输出上无论输出 pspec 如何,物理数据移动都是不可能的。相反,out_specs
只是编码如何将块输出组装成 Array
,或者物理上如何解释跨设备的缓冲区作为单个逻辑 Array
的物理布局。
API 规范
from jax.sharding import Mesh Specs = PyTree[PartitionSpec] def shard_map( f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs, auto: collections.abc.Set[AxisName] = frozenset([]), check_rep: bool = True, ) -> Callable: ...
其中:
- 在
f
的主体中,像psum
这样的通信集合可以提及mesh
的轴名称; mesh
编码排列成数组并带有关联轴名称的设备,就像sharding.NamedSharding
一样;in_specs
和out_specs
是PartitionSpec
,可以用来从mesh
中仿射地提及轴名称,以表达输入和输出的切片/未连接和连接,分别对应于未提及名称的复制和取消铺设(断言-复制-因此-给我-一个-副本);auto
是对应于mesh
名称子集的可选轴名称,在主体中自动处理,如在调用者中,而不是手动处理;check_rep
是一个可选布尔值,指示静态检查out_specs
中是否存在任何复制错误,并且是否启用相关的自动微分优化(参见JEP)。
传递给f
的参数的形状与传递给shard_map
-of-f
的参数的形状具有相同的秩,f
的参数的形状从相应的shard_map
-of-f
的形状shape
和相应的PartitionSpec
spec
中粗略计算为tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))
。
集合教程
shard_map
不必是纯映射:函数应用可以通过集合与彼此通信,使用在mesh
参数中定义的轴名称。
请记住,shard_map
将函数映射到输入数据的分片或块,因此这样:
mesh = Mesh(jax.devices(), ('i',)) x = jnp.arange(16.) f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i')) y = f_shmapped(x)
计算相同的值,评估对相同参数值的f
的应用,如此参考函数:
def f_shmapped_ref(x): x_blocks = jnp.array_split(x, mesh.shape[0]) y_blocks = [f(x_blk) for x_blk in x_blocks] return jnp.concatenate(y_blocks)
我们将这些对不同参数分片的f
的应用称为函数实例。每个函数实例在不同的设备(或设备子集)上执行。
这些引用语义在f
中没有通信集合时有效。但是如果我们希望函数实例进行通信,即进行跨设备通信,该怎么办?也就是说,当f
包含一个集合时,引用语义是什么?假设f
只有一个集合,并且形式如下:
def f(x_blk): z_blk = f_part1(x_blk) u_blk = collective(z_blk, axis_name) v_blk = f_part2(x_blk, z_blk, u_blk) return v_blk
假设我们映射的唯一网格轴只有一个,并且axis_name
是其对应的名称。然后引用语义看起来更像是:
def f_shmapped_ref(x): x_blocks = jnp.array_split(x, mesh.shape[0]) z_blocks = [f_part1(x_blk) for x_blk in x_blocks] u_blocks = [collective_ref(i, z_blocks) for i in range(len(z_blocks))] v_blocks = [f_part2(x_blk, z_blk, u_blk) for x_blk, z_blk, u_blk in zip(x_blocks, z_blocks, u_blocks)] return jnp.concatenate(v_blocks)
注意,collective_ref
可能依赖于所有的z_blocks
。也就是说,虽然f_part1
和f_part2
独立地映射到块上,但是集合引入了跨块依赖。在物理上,这意味着跨设备的通信。确切地说,通信发生了什么,以及计算了什么值,取决于集合。
psum
最简单的集合可能是jax.lax.psum
,它沿着设备网格轴(或多个轴)计算全归约和。这里是一个玩具示例:
import jax import jax.numpy as jnp from jax import lax from jax.sharding import Mesh, NamedSharding, PartitionSpec as P from jax.experimental.shard_map import shard_map
mesh1d = Mesh(jax.devices()[:4], ('i',)) @partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None)) def f1(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.psum(x_block, 'i') print('AFTER:\n', y_block) return y_block
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2]) y = f1(x) print('FINAL RESULT:\n', y)
BEFORE: On TFRT_CPU_0 at mesh coordinates (i,) = (0,): [3 1 4 1] On TFRT_CPU_1 at mesh coordinates (i,) = (1,): [5 9 2 6] On TFRT_CPU_2 at mesh coordinates (i,) = (2,): [5 3 5 8] On TFRT_CPU_3 at mesh coordinates (i,) = (3,): [9 7 1 2] AFTER: On TFRT_CPU_0 at mesh coordinates (i,) = (0,): [22 20 12 17] On TFRT_CPU_1 at mesh coordinates (i,) = (1,): [22 20 12 17] On TFRT_CPU_2 at mesh coordinates (i,) = (2,): [22 20 12 17] On TFRT_CPU_3 at mesh coordinates (i,) = (3,): [22 20 12 17] FINAL RESULT: [22 20 12 17]
打印显示,每个函数应用都从其自己的参数值块x_block
开始。在psum
之后,每个函数应用都有相同的y_block
值,通过将应用的x_block
值求和而得到。
在计算中存在单个轴名称的情况下,我们可以说collective_ref
对于psum
的引用实现是:
def psum_ref(_, x_blocks): tot = sum(x_blocks) return [tot] * len(x_blocks)
还要注意,因为f1
返回y_block
,对'i'
进行psum
的结果,我们可以使用out_specs=P()
,这样调用者就可以得到单个逻辑副本的结果值,而不是平铺的结果。
当存在多个网格轴时,我们可以分别对每个轴执行psum
,或者同时对多个轴执行:
mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j')) @partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j')) def f2(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.psum(x_block, 'i') print('AFTER:\n', y_block) return y_block y = f2(jnp.arange(16).reshape(4, 4)) print('FINAL RESULT:\n', y)
BEFORE: On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0): [[0 1] [4 5]] On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1): [[2 3] [6 7]] On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0): [[ 8 9] [12 13]] On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1): [[10 11] [14 15]] AFTER: On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0): [[ 8 10] [16 18]] On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1): [[12 14] [20 22]] On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0): [[ 8 10] [16 18]] On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1): [[12 14] [20 22]] FINAL RESULT: [[ 8 10 12 14] [16 18 20 22]]
通过在网格轴 'i'
上应用 psum
,我们得到沿 'i'
轴相等的 y_block
值,但不沿 'j'
轴相等。(因此,我们可以使用 out_specs=P(None, 'j')
来获取沿该轴的单一逻辑结果。)
如果我们在两个轴上应用 psum
,则 y_block
值沿两个轴相等:
@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None)) def f3(x_block): print('BEFORE:\n', x_block) y_block = jax.lax.psum(x_block, ('i', 'j')) print('AFTER:\n', y_block) return y_block y = f3(jnp.arange(16).reshape(4, 4)) print('FINAL RESULT:\n', y)
BEFORE: On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0): [[0 1] [4 5]] On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1): [[2 3] [6 7]] On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0): [[ 8 9] [12 13]] On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1): [[10 11] [14 15]] AFTER: On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0): [[20 24] [36 40]] On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1): [[20 24] [36 40]] On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0): [[20 24] [36 40]] On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1): [[20 24] [36 40]] FINAL RESULT: [[20 24] [36 40]]
在机器学习中,我们经常使用 psum
来计算总损失,或者当我们在 shard_map
函数体内有一个 grad
时,计算总梯度。
接下来,我们将看到如何用其他基元实现 psum
,这些基元能够对其通信成本提供一些直观的理解。
JAX 中文文档(七)(2)https://developer.aliyun.com/article/1559697