JAX 中文文档(七)(1)

简介: JAX 中文文档(七)


原文:jax.readthedocs.io/en/latest/

使用 shard_map 的 SPMD 多设备并行性

原文:jax.readthedocs.io/en/latest/notebooks/shard_map.html

shard_map 是一种单程序多数据(SPMD)多设备并行性 API,用于在数据分片上映射函数。映射的函数应用或实例通过显式的集合通信操作进行通信。

shard_map 是与 jit 内置的自动编译器并行化互补且可组合的。使用 jit,你编写的代码就像为单个设备编写的一样,并且编译器可以自动将计算分区到多个设备上,在幕后生成每个设备的代码和通信集合。使用 shard_map,你可以控制自己的分区代码和显式集合。或者你可以同时进行一些操作:在组设备中手动控制同时保留组内设备分区给编译器。这两种方法可以根据需要混合、匹配和组合。

如果您熟悉 pmap,可以将 shard_map 视为其演进。它更具表现力、性能和与其他 JAX API 可组合。它甚至可以急切地工作,更易于调试!(更多信息,请参阅pmap 的详细比较。

通过阅读本教程,您将学习如何使用 shard_map 来完全控制您的多设备代码。您将详细了解它如何与 jax.jit 的自动并行化和 jax.grad 的自动微分结合使用。我们还将给出一些神经网络并行化策略的基本示例。

我们假设本教程在具有八个设备的环境中运行:

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices 

所以,让我们来看一个 shard_map 吧!

不多说了,这里是一个玩具例子:

from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map 
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))
a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 *  4.).reshape(16, 4)
@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P('y', None)),
         out_specs=P('x', None))
def matmul_basic(a_block, b_block):
  # a_block: f32[2, 8]
  # b_block: f32[8, 4]
  c_partialsum = jnp.dot(a_block, b_block)
  c_block = jax.lax.psum(c_partialsum, 'y')
  # c_block: f32[2, 4]
  return c_block
c = matmul_basic(a, b)   # c: f32[8, 4] 

这个函数通过执行本地块矩阵乘法,然后进行集合求和操作来并行计算矩阵乘积。我们可以检查结果是否正确:

from jax.tree_util import tree_map, tree_all
def allclose(a, b):
  return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))
allclose(c, jnp.dot(a, b)) 
True 

结果沿其行被分片:

jax.debug.visualize_array_sharding(c) 
CPU 0,1 
 CPU 2,3 
 CPU 4,5 
 CPU 6,7 

在高层次上,shard_map 在某种程度上类似于 vmappmap,因为我们在数组数据的部分上映射函数,但请注意

  • shard_map 将输入切片成块(输出由连接结果块形成),保持秩不变,而 vmap 则通过映射掉一个轴来减少秩;
  • mesh 参数允许我们控制计算和结果的精确设备放置;
  • 我们同时映射多个数据轴,并设置多个轴名称以进行集合操作(这里有 'x''y');
  • 因为我们还没有使用 jax.jit,一切都是急切地评估的,我们甚至可以打印中间值以进行调试。

上述代码执行与此 jax.jit 自动并行化代码相同的计算:

from jax.sharding import NamedSharding
a = jax.device_put(a, NamedSharding(mesh, P('x', 'y')))
b = jax.device_put(b, NamedSharding(mesh, P('y', None)))
@jax.jit
def matmul_reference(a, b):
  c = jnp.dot(a, b)
  return jax.lax.with_sharding_constraint(c, NamedSharding(mesh, P('x', None)))
c_ref = matmul_reference(a, b)
allclose(c_ref, jnp.dot(a, b)) 
True 

我们可以将 shard_map 视为根据其 meshin_specs 参数在其输入上执行 device_putwith_sharding_constraint,因此 matmul_basic 操作的块与 matmul_reference 中的相同:

print('a blocks:'); jax.debug.visualize_array_sharding(a)
print('b blocks:'); jax.debug.visualize_array_sharding(b)
print('c blocks:'); jax.debug.visualize_array_sharding(c) 
a blocks:
b blocks:
c blocks: 
CPU 0CPU 1 
 CPU 2CPU 3 
 CPU 4CPU 5 
 CPU 6CPU 7 
CPU 0,2,4,6
CPU 1,3,5,7
CPU 0,1 
 CPU 2,3 
 CPU 4,5 
 CPU 6,7 

放慢速度,从基础开始!

降维与保持秩的映射

我们可以将 vmappmap 看作是沿轴(例如将 2D 矩阵解包成其 1D 行)对每个数组输入应用其主体函数,然后将结果堆叠在一起,至少在不涉及集合操作时是这样的:

def check_vmap(f, xs):
  ans = jax.vmap(f, in_axes=(0,), out_axes=0)(xs)
  expected = jnp.stack([f(x) for x in xs])  # vmap reference semantics
  print(allclose(ans, expected))
check_vmap(lambda x: x @ x, jnp.arange(12).reshape(4, 3)) 
True 

例如,如果 xs 的形状为 f32[8,5],那么每个 x 的形状将为 f32[5],如果每个 f(x) 的形状为 f32[3,7],那么最终堆叠的结果 vmap(f)(xs) 的形状将为 f32[8,3,7]。也就是说,函数 f 的每个应用都以比 vmap(f) 对应参数少一个轴的输入作为参数。我们可以说这些是降维映射,输入/输出的解包/堆叠。

函数 f 的逻辑应用数量,或称为 f实例,取决于被映射输入轴的大小:例如,如果我们映射一个大小为 8 的输入轴,语义上我们得到函数的 8 个逻辑应用。

相比之下,shard_map 并没有这种降维行为。相反,我们可以将其视为沿输入轴切片(或“取消连接”)成块,应用主体函数,然后将结果再次连接在一起(同样是在不涉及集合操作时):

import numpy as np
devices = np.array(jax.devices()[:4])
mesh = Mesh(devices, ('i',))  # mesh.shape['i'] = 4
def check_shmap(f, y):
  ans = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))(y)
  expected = jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, mesh.shape['i'])])
  print(allclose(ans, expected))
check_shmap(lambda x: x.T @ x, jnp.arange(32).reshape(8, 4)) 
True 

回想一下,jnp.split 将其输入切片为相同大小的块,以便如果在上述示例中 y 的形状为 f32[8,5],那么每个 y_blk 的形状将为 f32[2,5],如果每个 f(y_blk) 的形状为 f32[3,7],那么最终连接的结果 shard_map(f, ...)(y) 的形状将为 f32[12,7]。因此,shard_map 对其输入进行保持秩的映射,输入/输出的取消连接/连接。

函数 f 的逻辑应用数量由网格大小决定,而不是任何输入轴的大小:例如,如果我们有总大小为 4 的网格(即在 4 个设备上),那么语义上我们得到函数的 4 个逻辑应用,对应于物理计算这些函数的 4 个设备。

控制每个输入如何分割(取消连接)并与 in_specs 平铺

每个 in_specs 通过 PartitionSpec  标识某些对应输入数组轴的网格轴名称,表示如何将该输入分割(或解串联)为应用体函数的块。该标识确定了碎片大小;当输入轴与网格轴标识为同一时,输入沿该逻辑轴分割(解串联)为数目等于相应网格轴大小的片段。(如果相应的网格轴大小不能整除输入数组轴大小,则出错。)如果输入的  pspec 未提及网格轴名称,则在该网格轴上没有分割。例如:

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('i', 'j'))
@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
  print(x_block.shape)  # prints (3, 12)
  return x_block
x1 = jnp.arange(12 * 12).reshape(12, 12)
y = f1(x1) 
(3, 12) 

在这里,因为输入 pspec 未提及网格轴名称 'j',因此没有输入数组轴沿该网格轴进行分割;类似地,因为输入数组的第二轴没有标识(因此没有沿任何网格轴分割),f1 的应用获得了沿该轴的完整视图。

当输入 pspec 中未提及网格轴时,我们可以始终重写为一个效率较低的程序,其中所有网格轴都被提及,但调用者执行 jnp.tile,例如:

@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
  print(x_block.shape)
  return x_block
x = jnp.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.shape['j']))  # x_ has shape (12, 24)
y = f2(x_)  # prints (3,12), and f1(x) == f2(x_) 
(3, 12) 

换句话说,因为每个输入 pspec 可以零次或一次提及每个网格轴名称,而不必精确一次提及每个名称,我们可以说除了其输入中内置的 jnp.split 外,shard_map 还有一个至少逻辑上内置的 jnp.tile(尽管根据参数的物理分片布局,可能不需要进行物理铺设)。要使用的铺设方式不唯一;我们也可以沿第一个轴进行铺设,并使用 pspec P(('j', 'i'), None)

可以在输入上进行物理数据移动,因为每个设备都需要有适当数据的副本。

通过 out_specs 控制每个由连接、块转置和使用 out_specs 反铺设组装的输出。

类似于输入端,out_specs 中的每个标识符通过名称将输出数组的一些轴与网格轴相关联,表示应如何将输出块(每个体函数应用的一个,或等效地每个物理设备一个)重新组装以形成最终输出值。例如,在上述 f1f2 的例子中,out_specs 表明我们应该沿两个轴连接在一起形成最终输出,结果在两种情况下都是形状为 (12, 24) 的数组 y。(如果体函数的输出形状,即输出块形状,对应的输出 pspec 描述的连接的秩过小,则出错。)

当一个网格轴名称在输出 pspec 中未被提及时,表示一个取消铺设:用户编写一个输出 pspec,其中未提及网格轴名称之一,他们保证输出块沿该网格轴是相等的,因此在输出中只使用一个沿该轴的块(而不是沿该网格轴连接所有块)。例如,使用与上述相同的网格:

x = jnp.array([[3.]])
z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', 'j'))()
print(z)  # prints the same as jnp.tile(x, (4, 2))
z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P('i', None))()
print(z)  # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))
z = shard_map(lambda: x, mesh=mesh, in_specs=(), out_specs=P(None, None))()
print(z)  # prints the same as jnp.tile(x, (1, 1)), or just x 
[[3\. 3.]
 [3\. 3.]
 [3\. 3.]
 [3\. 3.]]
[[3.]
 [3.]
 [3.]
 [3.]]
[[3.]] 

闭合在数组值上的主体函数等效于将其作为具有相应输入 pspec 的增强传递。作为另一个示例,更接近于上述其他示例:

@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
  return jax.lax.psum(x_block, 'j')
x = jnp.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape) 
(12, 6) 

结果的第二轴大小为 6,输入的第二轴大小的一半。在这种情况下,通过在输出 pspec 中未提及网格轴名称 'j' 来表达取消铺设是安全的,因为集体 psum 确保每个输出块沿相应的网格轴是相等的。以下是两个更改输出 pspec 中提及的网格轴的示例:

@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
  return jax.lax.psum(x_block, 'i')
x = jnp.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape)  # (3,12)
@partial(shard_map, mesh=mesh, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
  return jax.lax.psum(x_block, ('i', 'j'))
y5 = f5(x)
print(y5.shape)  # (3,6) 
(3, 12)
(3, 6) 

在物理方面,在输出 pspec 中未提及网格轴名称将使用沿该网格轴复制布局从输出设备缓冲区组装 Array

没有运行时检查,以确保输出块实际上沿网格轴是相等的,从而可以取消铺设,或者等效地说,相应的物理缓冲区具有相等的值,因此可以被解释为单个逻辑数组的复制布局。但是,我们可以提供一个静态检查机制,在所有潜在不正确的程序上引发错误。

因为 out_specs 可以零次或一次提及网格轴名称,并且可以以任何顺序提及,所以除了其输出中内置的 jnp.concatenate 外,shard_map 还包括 取消铺设块转置

输出上无论输出 pspec 如何,物理数据移动都是不可能的。相反,out_specs 只是编码如何将块输出组装成 Array,或者物理上如何解释跨设备的缓冲区作为单个逻辑 Array 的物理布局。

API 规范

from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]
def shard_map(
    f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs,
    auto: collections.abc.Set[AxisName] = frozenset([]),
    check_rep: bool = True,
) -> Callable:
  ... 

其中:

  • f 的主体中,像 psum 这样的通信集合可以提及 mesh 的轴名称;
  • mesh 编码排列成数组并带有关联轴名称的设备,就像 sharding.NamedSharding 一样;
  • in_specsout_specsPartitionSpec,可以用来从 mesh 中仿射地提及轴名称,以表达输入和输出的切片/未连接和连接,分别对应于未提及名称的复制和取消铺设(断言-复制-因此-给我-一个-副本);
  • auto 是对应于 mesh 名称子集的可选轴名称,在主体中自动处理,如在调用者中,而不是手动处理;
  • check_rep是一个可选布尔值,指示静态检查out_specs中是否存在任何复制错误,并且是否启用相关的自动微分优化(参见JEP)。

传递给f的参数的形状与传递给shard_map-of-f的参数的形状具有相同的秩,f的参数的形状从相应的shard_map-of-f的形状shape和相应的PartitionSpec spec中粗略计算为tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))

集合教程

shard_map不必是纯映射:函数应用可以通过集合与彼此通信,使用在mesh参数中定义的轴名称。

请记住,shard_map将函数映射到输入数据的分片或块,因此这样:

mesh = Mesh(jax.devices(), ('i',))
x = jnp.arange(16.)
f_shmapped = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))
y = f_shmapped(x) 

计算相同的值,评估对相同参数值的f的应用,如此参考函数:

def f_shmapped_ref(x):
  x_blocks = jnp.array_split(x, mesh.shape[0])
  y_blocks = [f(x_blk) for x_blk in x_blocks]
  return jnp.concatenate(y_blocks) 

我们将这些对不同参数分片的f的应用称为函数实例。每个函数实例在不同的设备(或设备子集)上执行。

这些引用语义在f中没有通信集合时有效。但是如果我们希望函数实例进行通信,即进行跨设备通信,该怎么办?也就是说,当f包含一个集合时,引用语义是什么?假设f只有一个集合,并且形式如下:

def f(x_blk):
  z_blk = f_part1(x_blk)
  u_blk = collective(z_blk, axis_name)
  v_blk = f_part2(x_blk, z_blk, u_blk)
  return v_blk 

假设我们映射的唯一网格轴只有一个,并且axis_name是其对应的名称。然后引用语义看起来更像是:

def f_shmapped_ref(x):
  x_blocks = jnp.array_split(x, mesh.shape[0])
  z_blocks = [f_part1(x_blk) for x_blk in x_blocks]
  u_blocks = [collective_ref(i, z_blocks) for i in range(len(z_blocks))]
  v_blocks = [f_part2(x_blk, z_blk, u_blk) for x_blk, z_blk, u_blk
              in zip(x_blocks, z_blocks, u_blocks)]
  return jnp.concatenate(v_blocks) 

注意,collective_ref可能依赖于所有的z_blocks。也就是说,虽然f_part1f_part2独立地映射到块上,但是集合引入了跨块依赖。在物理上,这意味着跨设备的通信。确切地说,通信发生了什么,以及计算了什么值,取决于集合。

psum

最简单的集合可能是jax.lax.psum,它沿着设备网格轴(或多个轴)计算全归约和。这里是一个玩具示例:

import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map 
mesh1d = Mesh(jax.devices()[:4], ('i',))
@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))
def f1(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, 'i')
  print('AFTER:\n', y_block)
  return y_block 
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])
y = f1(x)
print('FINAL RESULT:\n', y) 
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[3 1 4 1]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[5 9 2 6]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[5 3 5 8]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[9 7 1 2]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i,) = (0,):
[22 20 12 17]
On TFRT_CPU_1 at mesh coordinates (i,) = (1,):
[22 20 12 17]
On TFRT_CPU_2 at mesh coordinates (i,) = (2,):
[22 20 12 17]
On TFRT_CPU_3 at mesh coordinates (i,) = (3,):
[22 20 12 17]
FINAL RESULT:
 [22 20 12 17] 

打印显示,每个函数应用都从其自己的参数值块x_block开始。在psum之后,每个函数应用都有相同的y_block值,通过将应用的x_block值求和而得到。

在计算中存在单个轴名称的情况下,我们可以说collective_ref对于psum的引用实现是:

def psum_ref(_, x_blocks):
  tot = sum(x_blocks)
  return [tot] * len(x_blocks) 

还要注意,因为f1返回y_block,对'i'进行psum的结果,我们可以使用out_specs=P(),这样调用者就可以得到单个逻辑副本的结果值,而不是平铺的结果。

当存在多个网格轴时,我们可以分别对每个轴执行psum,或者同时对多个轴执行:

mesh2d = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('i', 'j'))
@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f2(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, 'i')
  print('AFTER:\n', y_block)
  return y_block
y = f2(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y) 
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
 [4 5]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
 [6 7]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8  9]
 [12 13]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
 [14 15]]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[ 8 10]
 [16 18]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[12 14]
 [20 22]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8 10]
 [16 18]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[12 14]
 [20 22]]
FINAL RESULT:
 [[ 8 10 12 14]
 [16 18 20 22]] 

通过在网格轴 'i' 上应用 psum,我们得到沿 'i' 轴相等的 y_block 值,但不沿 'j' 轴相等。(因此,我们可以使用 out_specs=P(None, 'j') 来获取沿该轴的单一逻辑结果。)

如果我们在两个轴上应用 psum,则 y_block 值沿两个轴相等:

@partial(shard_map, mesh=mesh2d, in_specs=P('i', 'j'), out_specs=P(None, None))
def f3(x_block):
  print('BEFORE:\n', x_block)
  y_block = jax.lax.psum(x_block, ('i', 'j'))
  print('AFTER:\n', y_block)
  return y_block
y = f3(jnp.arange(16).reshape(4, 4))
print('FINAL RESULT:\n', y) 
BEFORE:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[0 1]
 [4 5]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[2 3]
 [6 7]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[ 8  9]
 [12 13]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[10 11]
 [14 15]]
AFTER:
 On TFRT_CPU_0 at mesh coordinates (i, j,) = (0, 0):
[[20 24]
 [36 40]]
On TFRT_CPU_1 at mesh coordinates (i, j,) = (0, 1):
[[20 24]
 [36 40]]
On TFRT_CPU_2 at mesh coordinates (i, j,) = (1, 0):
[[20 24]
 [36 40]]
On TFRT_CPU_3 at mesh coordinates (i, j,) = (1, 1):
[[20 24]
 [36 40]]
FINAL RESULT:
 [[20 24]
 [36 40]] 

在机器学习中,我们经常使用 psum 来计算总损失,或者当我们在 shard_map 函数体内有一个 grad 时,计算总梯度。

接下来,我们将看到如何用其他基元实现 psum,这些基元能够对其通信成本提供一些直观的理解。


JAX 中文文档(七)(2)https://developer.aliyun.com/article/1559697

相关文章
|
3月前
|
并行计算 API C++
JAX 中文文档(九)(4)
JAX 中文文档(九)
33 1
|
3月前
|
机器学习/深度学习 存储 移动开发
JAX 中文文档(八)(1)
JAX 中文文档(八)
22 1
|
3月前
|
机器学习/深度学习 PyTorch API
JAX 中文文档(六)(1)
JAX 中文文档(六)
26 0
JAX 中文文档(六)(1)
|
3月前
|
测试技术 TensorFlow 算法框架/工具
JAX 中文文档(五)(2)
JAX 中文文档(五)
35 0
|
3月前
|
编译器 异构计算 索引
JAX 中文文档(五)(4)
JAX 中文文档(五)
54 0
|
3月前
|
存储 编译器 芯片
JAX 中文文档(五)(5)
JAX 中文文档(五)
26 0
|
3月前
|
编译器 异构计算 Python
JAX 中文文档(四)(2)
JAX 中文文档(四)
21 0
|
3月前
|
存储 机器学习/深度学习 编译器
JAX 中文文档(九)(1)
JAX 中文文档(九)
34 0
|
3月前
|
缓存 PyTorch API
JAX 中文文档(一)(3)
JAX 中文文档(一)
37 0
|
3月前
|
API 索引 Python
JAX 中文文档(三)(4)
JAX 中文文档(三)
18 0