JAX 中文文档(六)(3)

简介: JAX 中文文档(六)

JAX 中文文档(六)(2)https://developer.aliyun.com/article/1559682


分布式数组和自动并行化

原文:jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html

[外链图片转存中…(img-WY6D3kNB-1718950514657)]

本教程讨论了通过 jax.Array 实现的并行计算,这是 JAX v0.4.1 及更高版本中可用的统一数组对象模型。

import os
import functools
from typing import Optional
import numpy as np
import jax
import jax.numpy as jnp 

⚠️ 警告:此笔记本需要 8 个设备才能运行。

if len(jax.local_devices()) < 8:
  raise Exception("Notebook requires 8 devices to run") 

简介和一个快速示例

通过阅读这本教程笔记本,您将了解 jax.Array,一种用于表示数组的统一数据类型,即使物理存储跨越多个设备。您还将学习如何使用 jax.Arrayjax.jit 结合,实现基于编译器的自动并行化。

在我们逐步思考之前,这里有一个快速示例。首先,我们将创建一个跨多个设备分片的 jax.Array

from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding 
# Create a Sharding object to distribute a value across devices:
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) 
# Create an array of random values:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
y = jax.device_put(x, sharding.reshape(4, 2))
jax.debug.visualize_array_sharding(y) 
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘ 

接下来,我们将对其应用计算,并可视化结果值如何存储在多个设备上:

z = jnp.sin(y)
jax.debug.visualize_array_sharding(z) 
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘ 

jnp.sin 应用的评估已自动并行化,该应用跨存储输入值(和输出值)的设备:

# `x` is present on a single device
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready() 
The slowest run took 13.32 times longer than the fastest. This could mean that an intermediate result is being cached 
5 loops, best of 5: 9.69 ms per loop 
# `y` is sharded across 8 devices.
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready() 
5 loops, best of 5: 1.86 ms per loop 

现在让我们更详细地查看每个部分!

Sharding 描述了如何将数组值布局在跨设备的内存中。

Sharding 基础知识和 PositionalSharding 子类

要在多个设备上并行计算,我们首先必须在多个设备上布置输入数据。

在 JAX 中,Sharding 对象描述了分布式内存布局。它们可以与 jax.device_put 结合使用,生成具有分布式布局的值。

例如,这里是一个单设备 Sharding 的值:

import jax
x = jax.random.normal(jax.random.key(0), (8192, 8192)) 
jax.debug.visualize_array_sharding(x) 
┌───────────────────────┐
│                       │
│                       │
│                       │
│                       │
│         TPU 0         │
│                       │
│                       │
│                       │
│                       │
└───────────────────────┘ 

在这里,我们使用 jax.debug.visualize_array_sharding 函数来展示内存中存储值 x 的位置。整个 x 存储在单个设备上,所以可视化效果相当无聊!

但是我们可以通过使用 jax.device_putSharding 对象将 x 分布在多个设备上。首先,我们使用 mesh_utils.create_device_mesh 制作一个 Devicesnumpy.ndarray,该函数考虑了硬件拓扑以确定 Device 的顺序:

from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((8,)) 

然后,我们创建一个 PositionalSharding 并与 device_put 一起使用:

from jax.sharding import PositionalSharding
sharding = PositionalSharding(devices)
x = jax.device_put(x, sharding.reshape(8, 1))
jax.debug.visualize_array_sharding(x) 
┌───────────────────────┐
│         TPU 0         │
├───────────────────────┤
│         TPU 1         │
├───────────────────────┤
│         TPU 2         │
├───────────────────────┤
│         TPU 3         │
├───────────────────────┤
│         TPU 6         │
├───────────────────────┤
│         TPU 7         │
├───────────────────────┤
│         TPU 4         │
├───────────────────────┤
│         TPU 5         │
└───────────────────────┘ 

这里的 sharding 是一个 PositionalSharding,它的作用类似于一个具有设备集合作为元素的数组:

sharding 
PositionalSharding([{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}]) 

这里的设备编号不是按数字顺序排列的,因为网格反映了设备的基础环形拓扑结构。

通过编写 PositionalSharding(ndarray_of_devices),我们确定了设备顺序和初始形状。然后我们可以对其进行重新形状化:

sharding.reshape(8, 1) 
PositionalSharding([[{TPU 0}]
                    [{TPU 1}]
                    [{TPU 2}]
                    [{TPU 3}]
                    [{TPU 6}]
                    [{TPU 7}]
                    [{TPU 4}]
                    [{TPU 5}]]) 
sharding.reshape(4, 2) 
PositionalSharding([[{TPU 0} {TPU 1}]
                    [{TPU 2} {TPU 3}]
                    [{TPU 6} {TPU 7}]
                    [{TPU 4} {TPU 5}]]) 

要使用device_put与数据数组x,我们可以将sharding重新形状为与x.shape同余的形状,这意味着具有与x.shape相同长度的形状,并且其中每个元素均匀地分割对应x.shape的元素:

def is_congruent(x_shape: Sequence[int], sharding_shape: Sequence[int]) -> bool:
  return (len(x_shape) == len(sharding_shape) and
          all(d1 % d2 == 0 for d1, d2 in zip(x_shape, sharding_shape))) 

例如,我们可以将sharding重新形状为(4, 2),然后在device_put中使用它:

sharding = sharding.reshape(4, 2)
print(sharding) 
PositionalSharding([[{TPU 0} {TPU 1}]
                    [{TPU 2} {TPU 3}]
                    [{TPU 6} {TPU 7}]
                    [{TPU 4} {TPU 5}]]) 
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y) 
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘ 

这里的y代表与x相同的,但其片段(即切片)存储在不同设备的内存中。

不同的PositionalSharding形状会导致结果的不同分布布局(即分片):

sharding = sharding.reshape(1, 8)
print(sharding) 
PositionalSharding([[{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}]]) 
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y) 
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 6 │ TPU 7 │ TPU 4 │ TPU 5 │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘ 

在某些情况下,我们不只是想将x的每个切片存储在单个设备的内存中;我们可能希望在多个设备的内存中复制一些切片,即在多个设备的内存中存储切片的值。

使用PositionalSharding,我们可以通过调用 reducer 方法replicate来表达复制:

sharding = sharding.reshape(4, 2)
print(sharding.replicate(axis=0, keepdims=True)) 
PositionalSharding([[{TPU 0, 2, 4, 6} {TPU 1, 3, 5, 7}]]) 
y = jax.device_put(x, sharding.replicate(axis=0, keepdims=True))
jax.debug.visualize_array_sharding(y) 
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘ 

这里的可视化显示了x沿其第二维以两种方式分片(而不沿第一维分片),每个片段都复制了四种方式(即存储在四个设备内存中)。

replicate方法类似于熟悉的 NumPy 数组缩减方法,如.sum().prod()。它沿着一个轴执行集合并操作。因此,如果sharding的形状为(4, 2),那么sharding.replicate(0, keepdims=True)的形状为(1, 2)sharding.replicate(1, keepdims=True)的形状为(4, 1)。与 NumPy 方法不同,keepdims=True实际上是默认的,因此减少的轴不会被压缩:

print(sharding.replicate(0).shape)
print(sharding.replicate(1).shape) 
(1, 2)
(4, 1) 
y = jax.device_put(x, sharding.replicate(1))
jax.debug.visualize_array_sharding(y) 
┌───────────────────────┐
│        TPU 0,1        │
├───────────────────────┤
│        TPU 2,3        │
├───────────────────────┤
│        TPU 6,7        │
├───────────────────────┤
│        TPU 4,5        │
└───────────────────────┘ 

NamedSharding提供了一种使用名称表达分片的方式。

到目前为止,我们已经使用了PositionalSharding,但还有其他表达分片的替代方法。实际上,Sharding是一个接口,任何实现该接口的类都可以与device_put等函数一起使用。

另一种方便的表达分片的方法是使用NamedSharding

from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
from jax.experimental import mesh_utils
P = PartitionSpec
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('a', 'b'))
y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
jax.debug.visualize_array_sharding(y) 
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘ 

我们可以定义一个辅助函数使事情更简单:

devices = mesh_utils.create_device_mesh((4, 2))
default_mesh = Mesh(devices, axis_names=('a', 'b'))
def mesh_sharding(
    pspec: PartitionSpec, mesh: Optional[Mesh] = None,
  ) -> NamedSharding:
  if mesh is None:
    mesh = default_mesh
  return NamedSharding(mesh, pspec) 
y = jax.device_put(x, mesh_sharding(P('a', 'b')))
jax.debug.visualize_array_sharding(y) 
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘ 

在这里,我们使用P('a', 'b')来表达x的第一和第二轴应该分片到设备网格轴'a''b'上。我们可以轻松切换到P('b', 'a')以在不同设备上分片x的轴:

y = jax.device_put(x, mesh_sharding(P('b', 'a')))
jax.debug.visualize_array_sharding(y) 
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │
│       │       │       │       │
│       │       │       │       │
├───────┼───────┼───────┼───────┤
│       │       │       │       │
│ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘ 
# This `None` means that `x` is not sharded on its second dimension,
# and since the Mesh axis name 'b' is not mentioned, shards are
# replicated across it.
y = jax.device_put(x, mesh_sharding(P('a', None)))
jax.debug.visualize_array_sharding(y) 
┌───────────────────────┐
│        TPU 0,1        │
├───────────────────────┤
│        TPU 2,3        │
├───────────────────────┤
│        TPU 6,7        │
├───────────────────────┤
│        TPU 4,5        │
└───────────────────────┘ 

这里,因为P('a', None)没有提及Mesh轴名'b',我们在轴'b'上得到了复制。这里的None只是一个占位符,用于与值x的第二轴对齐,而不表示在任何网格轴上进行分片。(简写方式是,尾部的None可以省略,因此P('a', None)的意思与P('a')相同。但是明确说明并不会有害!)

要仅在x的第二轴上进行分片,我们可以在PartitionSpec中使用None占位符。

y = jax.device_put(x, mesh_sharding(P(None, 'b')))
jax.debug.visualize_array_sharding(y) 
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘ 
y = jax.device_put(x, mesh_sharding(P(None, 'a')))
jax.debug.visualize_array_sharding(y) 
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│TPU 0,1│TPU 2,3│TPU 6,7│TPU 4,5│
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘ 

对于固定的网格,我们甚至可以将x的一个逻辑轴分割到多个设备网格轴上:

y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))
jax.debug.visualize_array_sharding(y)
┌───────────────────────┐
│         TPU 0         │
├───────────────────────┤
│         TPU 1         │
├───────────────────────┤
│         TPU 2         │
├───────────────────────┤
│         TPU 3         │
├───────────────────────┤
│         TPU 6         │
├───────────────────────┤
│         TPU 7         │
├───────────────────────┤
│         TPU 4         │
├───────────────────────┤
│         TPU 5         │
└───────────────────────┘ 

使用NamedSharding可以轻松定义一次设备网格并为其轴命名,然后只需在需要时在每个device_putPartitionSpec中引用这些名称。


JAX 中文文档(六)(4)https://developer.aliyun.com/article/1559684

相关文章
|
3天前
|
机器学习/深度学习 存储 移动开发
JAX 中文文档(八)(1)
JAX 中文文档(八)
9 1
|
3天前
|
并行计算 API C++
JAX 中文文档(九)(4)
JAX 中文文档(九)
11 1
|
3天前
|
机器学习/深度学习 PyTorch API
JAX 中文文档(六)(1)
JAX 中文文档(六)
12 0
JAX 中文文档(六)(1)
|
3天前
|
C++ 索引 Python
JAX 中文文档(九)(2)
JAX 中文文档(九)
7 0
|
3天前
|
测试技术 API Python
JAX 中文文档(八)(4)
JAX 中文文档(八)
9 0
|
3天前
|
Serverless C++ Python
JAX 中文文档(九)(5)
JAX 中文文档(九)
10 0
|
3天前
|
机器学习/深度学习 API 索引
JAX 中文文档(二)(2)
JAX 中文文档(二)
10 0
|
3天前
|
机器学习/深度学习 异构计算 Python
JAX 中文文档(四)(3)
JAX 中文文档(四)
8 0
|
3天前
|
API 索引 Python
JAX 中文文档(三)(4)
JAX 中文文档(三)
6 0
|
3天前
|
机器学习/深度学习 缓存 API
JAX 中文文档(一)(4)
JAX 中文文档(一)
8 0