JAX 中文文档(六)(2)https://developer.aliyun.com/article/1559682
分布式数组和自动并行化
原文:
jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
[外链图片转存中…(img-WY6D3kNB-1718950514657)]
本教程讨论了通过 jax.Array
实现的并行计算,这是 JAX v0.4.1 及更高版本中可用的统一数组对象模型。
import os import functools from typing import Optional import numpy as np import jax import jax.numpy as jnp
⚠️ 警告:此笔记本需要 8 个设备才能运行。
if len(jax.local_devices()) < 8: raise Exception("Notebook requires 8 devices to run")
简介和一个快速示例
通过阅读这本教程笔记本,您将了解 jax.Array
,一种用于表示数组的统一数据类型,即使物理存储跨越多个设备。您还将学习如何使用 jax.Array
与 jax.jit
结合,实现基于编译器的自动并行化。
在我们逐步思考之前,这里有一个快速示例。首先,我们将创建一个跨多个设备分片的 jax.Array
:
from jax.experimental import mesh_utils from jax.sharding import PositionalSharding
# Create a Sharding object to distribute a value across devices: sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
# Create an array of random values: x = jax.random.normal(jax.random.key(0), (8192, 8192)) # and use jax.device_put to distribute it across devices: y = jax.device_put(x, sharding.reshape(4, 2)) jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
接下来,我们将对其应用计算,并可视化结果值如何存储在多个设备上:
z = jnp.sin(y) jax.debug.visualize_array_sharding(z)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
jnp.sin
应用的评估已自动并行化,该应用跨存储输入值(和输出值)的设备:
# `x` is present on a single device %timeit -n 5 -r 5 jnp.sin(x).block_until_ready()
The slowest run took 13.32 times longer than the fastest. This could mean that an intermediate result is being cached 5 loops, best of 5: 9.69 ms per loop
# `y` is sharded across 8 devices. %timeit -n 5 -r 5 jnp.sin(y).block_until_ready()
5 loops, best of 5: 1.86 ms per loop
现在让我们更详细地查看每个部分!
Sharding
描述了如何将数组值布局在跨设备的内存中。
Sharding 基础知识和 PositionalSharding
子类
要在多个设备上并行计算,我们首先必须在多个设备上布置输入数据。
在 JAX 中,Sharding
对象描述了分布式内存布局。它们可以与 jax.device_put
结合使用,生成具有分布式布局的值。
例如,这里是一个单设备 Sharding
的值:
import jax x = jax.random.normal(jax.random.key(0), (8192, 8192))
jax.debug.visualize_array_sharding(x)
┌───────────────────────┐ │ │ │ │ │ │ │ │ │ TPU 0 │ │ │ │ │ │ │ │ │ └───────────────────────┘
在这里,我们使用 jax.debug.visualize_array_sharding
函数来展示内存中存储值 x
的位置。整个 x
存储在单个设备上,所以可视化效果相当无聊!
但是我们可以通过使用 jax.device_put
和 Sharding
对象将 x
分布在多个设备上。首先,我们使用 mesh_utils.create_device_mesh
制作一个 Devices
的 numpy.ndarray
,该函数考虑了硬件拓扑以确定 Device
的顺序:
from jax.experimental import mesh_utils devices = mesh_utils.create_device_mesh((8,))
然后,我们创建一个 PositionalSharding
并与 device_put
一起使用:
from jax.sharding import PositionalSharding sharding = PositionalSharding(devices) x = jax.device_put(x, sharding.reshape(8, 1)) jax.debug.visualize_array_sharding(x)
┌───────────────────────┐ │ TPU 0 │ ├───────────────────────┤ │ TPU 1 │ ├───────────────────────┤ │ TPU 2 │ ├───────────────────────┤ │ TPU 3 │ ├───────────────────────┤ │ TPU 6 │ ├───────────────────────┤ │ TPU 7 │ ├───────────────────────┤ │ TPU 4 │ ├───────────────────────┤ │ TPU 5 │ └───────────────────────┘
这里的 sharding
是一个 PositionalSharding
,它的作用类似于一个具有设备集合作为元素的数组:
sharding
PositionalSharding([{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}])
这里的设备编号不是按数字顺序排列的,因为网格反映了设备的基础环形拓扑结构。
通过编写 PositionalSharding(ndarray_of_devices)
,我们确定了设备顺序和初始形状。然后我们可以对其进行重新形状化:
sharding.reshape(8, 1)
PositionalSharding([[{TPU 0}] [{TPU 1}] [{TPU 2}] [{TPU 3}] [{TPU 6}] [{TPU 7}] [{TPU 4}] [{TPU 5}]])
sharding.reshape(4, 2)
PositionalSharding([[{TPU 0} {TPU 1}] [{TPU 2} {TPU 3}] [{TPU 6} {TPU 7}] [{TPU 4} {TPU 5}]])
要使用device_put
与数据数组x
,我们可以将sharding
重新形状为与x.shape
同余的形状,这意味着具有与x.shape
相同长度的形状,并且其中每个元素均匀地分割对应x.shape
的元素:
def is_congruent(x_shape: Sequence[int], sharding_shape: Sequence[int]) -> bool: return (len(x_shape) == len(sharding_shape) and all(d1 % d2 == 0 for d1, d2 in zip(x_shape, sharding_shape)))
例如,我们可以将sharding
重新形状为(4, 2)
,然后在device_put
中使用它:
sharding = sharding.reshape(4, 2) print(sharding)
PositionalSharding([[{TPU 0} {TPU 1}] [{TPU 2} {TPU 3}] [{TPU 6} {TPU 7}] [{TPU 4} {TPU 5}]])
y = jax.device_put(x, sharding) jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
这里的y
代表与x
相同的值,但其片段(即切片)存储在不同设备的内存中。
不同的PositionalSharding
形状会导致结果的不同分布布局(即分片):
sharding = sharding.reshape(1, 8) print(sharding)
PositionalSharding([[{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}]])
y = jax.device_put(x, sharding) jax.debug.visualize_array_sharding(y)
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 6 │ TPU 7 │ TPU 4 │ TPU 5 │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ └───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
在某些情况下,我们不只是想将x
的每个切片存储在单个设备的内存中;我们可能希望在多个设备的内存中复制一些切片,即在多个设备的内存中存储切片的值。
使用PositionalSharding
,我们可以通过调用 reducer 方法replicate
来表达复制:
sharding = sharding.reshape(4, 2) print(sharding.replicate(axis=0, keepdims=True))
PositionalSharding([[{TPU 0, 2, 4, 6} {TPU 1, 3, 5, 7}]])
y = jax.device_put(x, sharding.replicate(axis=0, keepdims=True)) jax.debug.visualize_array_sharding(y)
┌───────────┬───────────┐ │ │ │ │ │ │ │ │ │ │ │ │ │TPU 0,2,4,6│TPU 1,3,5,7│ │ │ │ │ │ │ │ │ │ │ │ │ └───────────┴───────────┘
这里的可视化显示了x
沿其第二维以两种方式分片(而不沿第一维分片),每个片段都复制了四种方式(即存储在四个设备内存中)。
replicate
方法类似于熟悉的 NumPy 数组缩减方法,如.sum()
和.prod()
。它沿着一个轴执行集合并操作。因此,如果sharding
的形状为(4, 2)
,那么sharding.replicate(0, keepdims=True)
的形状为(1, 2)
,sharding.replicate(1, keepdims=True)
的形状为(4, 1)
。与 NumPy 方法不同,keepdims=True
实际上是默认的,因此减少的轴不会被压缩:
print(sharding.replicate(0).shape) print(sharding.replicate(1).shape)
(1, 2) (4, 1)
y = jax.device_put(x, sharding.replicate(1)) jax.debug.visualize_array_sharding(y)
┌───────────────────────┐ │ TPU 0,1 │ ├───────────────────────┤ │ TPU 2,3 │ ├───────────────────────┤ │ TPU 6,7 │ ├───────────────────────┤ │ TPU 4,5 │ └───────────────────────┘
NamedSharding
提供了一种使用名称表达分片的方式。
到目前为止,我们已经使用了PositionalSharding
,但还有其他表达分片的替代方法。实际上,Sharding
是一个接口,任何实现该接口的类都可以与device_put
等函数一起使用。
另一种方便的表达分片的方法是使用NamedSharding
:
from jax.sharding import Mesh from jax.sharding import PartitionSpec from jax.sharding import NamedSharding from jax.experimental import mesh_utils P = PartitionSpec devices = mesh_utils.create_device_mesh((4, 2)) mesh = Mesh(devices, axis_names=('a', 'b')) y = jax.device_put(x, NamedSharding(mesh, P('a', 'b'))) jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
我们可以定义一个辅助函数使事情更简单:
devices = mesh_utils.create_device_mesh((4, 2)) default_mesh = Mesh(devices, axis_names=('a', 'b')) def mesh_sharding( pspec: PartitionSpec, mesh: Optional[Mesh] = None, ) -> NamedSharding: if mesh is None: mesh = default_mesh return NamedSharding(mesh, pspec)
y = jax.device_put(x, mesh_sharding(P('a', 'b'))) jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
在这里,我们使用P('a', 'b')
来表达x
的第一和第二轴应该分片到设备网格轴'a'
和'b'
上。我们可以轻松切换到P('b', 'a')
以在不同设备上分片x
的轴:
y = jax.device_put(x, mesh_sharding(P('b', 'a'))) jax.debug.visualize_array_sharding(y)
┌───────┬───────┬───────┬───────┐ │ │ │ │ │ │ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │ │ │ │ │ │ │ │ │ │ │ ├───────┼───────┼───────┼───────┤ │ │ │ │ │ │ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │ │ │ │ │ │ │ │ │ │ │ └───────┴───────┴───────┴───────┘
# This `None` means that `x` is not sharded on its second dimension, # and since the Mesh axis name 'b' is not mentioned, shards are # replicated across it. y = jax.device_put(x, mesh_sharding(P('a', None))) jax.debug.visualize_array_sharding(y)
┌───────────────────────┐ │ TPU 0,1 │ ├───────────────────────┤ │ TPU 2,3 │ ├───────────────────────┤ │ TPU 6,7 │ ├───────────────────────┤ │ TPU 4,5 │ └───────────────────────┘
这里,因为P('a', None)
没有提及Mesh
轴名'b'
,我们在轴'b'
上得到了复制。这里的None
只是一个占位符,用于与值x
的第二轴对齐,而不表示在任何网格轴上进行分片。(简写方式是,尾部的None
可以省略,因此P('a', None)
的意思与P('a')
相同。但是明确说明并不会有害!)
要仅在x
的第二轴上进行分片,我们可以在PartitionSpec
中使用None
占位符。
y = jax.device_put(x, mesh_sharding(P(None, 'b'))) jax.debug.visualize_array_sharding(y)
┌───────────┬───────────┐ │ │ │ │ │ │ │ │ │ │ │ │ │TPU 0,2,4,6│TPU 1,3,5,7│ │ │ │ │ │ │ │ │ │ │ │ │ └───────────┴───────────┘
y = jax.device_put(x, mesh_sharding(P(None, 'a'))) jax.debug.visualize_array_sharding(y)
┌───────┬───────┬───────┬───────┐ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │TPU 0,1│TPU 2,3│TPU 6,7│TPU 4,5│ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ └───────┴───────┴───────┴───────┘
对于固定的网格,我们甚至可以将x
的一个逻辑轴分割到多个设备网格轴上:
y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None))) jax.debug.visualize_array_sharding(y)
┌───────────────────────┐ │ TPU 0 │ ├───────────────────────┤ │ TPU 1 │ ├───────────────────────┤ │ TPU 2 │ ├───────────────────────┤ │ TPU 3 │ ├───────────────────────┤ │ TPU 6 │ ├───────────────────────┤ │ TPU 7 │ ├───────────────────────┤ │ TPU 4 │ ├───────────────────────┤ │ TPU 5 │ └───────────────────────┘
使用NamedSharding
可以轻松定义一次设备网格并为其轴命名,然后只需在需要时在每个device_put
的PartitionSpec
中引用这些名称。
JAX 中文文档(六)(4)https://developer.aliyun.com/article/1559684