JAX 中文文档(六)(4)

简介: JAX 中文文档(六)

JAX 中文文档(六)(3)https://developer.aliyun.com/article/1559683


计算遵循数据分片并自动并行化

使用分片输入数据,编译器可以给我们并行计算。特别是,用 jax.jit 装饰的函数可以在分片数组上操作,而无需将数据复制到单个设备上。相反,计算遵循分片:基于输入数据的分片,编译器决定中间结果和输出值的分片,并并行评估它们,必要时甚至插入通信操作。

例如,最简单的计算是逐元素的:

from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) 
x = jax.device_put(x, sharding.reshape(4, 2))
print('input sharding:')
jax.debug.visualize_array_sharding(x)
y = jnp.sin(x)
print('output sharding:')
jax.debug.visualize_array_sharding(y) 
input sharding:
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
output sharding:
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘ 

这里对于逐元素操作 jnp.sin,编译器选择了输出分片与输入相同。此外,编译器自动并行化计算,因此每个设备都可以并行计算其输出片段。

换句话说,即使我们将 jnp.sin 的计算写成单台机器执行,编译器也会为我们拆分计算并在多个设备上执行。

我们不仅可以对逐元素操作执行相同操作。考虑使用分片输入的矩阵乘法:

y = jax.device_put(x, sharding.reshape(4, 2).replicate(1))
z = jax.device_put(x, sharding.reshape(4, 2).replicate(0))
print('lhs sharding:')
jax.debug.visualize_array_sharding(y)
print('rhs sharding:')
jax.debug.visualize_array_sharding(z)
w = jnp.dot(y, z)
print('out sharding:')
jax.debug.visualize_array_sharding(w) 
lhs sharding:
┌───────────────────────┐
│        TPU 0,1        │
├───────────────────────┤
│        TPU 2,3        │
├───────────────────────┤
│        TPU 6,7        │
├───────────────────────┤
│        TPU 4,5        │
└───────────────────────┘
rhs sharding:
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘
out sharding:
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘ 

这里编译器选择了输出分片,以便最大化并行计算:无需通信,每个设备已经具有计算其输出分片所需的输入分片。

我们如何确保它实际上是并行运行的?我们可以进行简单的时间实验:

x_single = jax.device_put(x, jax.devices()[0])
jax.debug.visualize_array_sharding(x_single) 
┌───────────────────────┐
│                       │
│                       │
│                       │
│                       │
│         TPU 0         │
│                       │
│                       │
│                       │
│                       │
└───────────────────────┘ 
np.allclose(jnp.dot(x_single, x_single),
            jnp.dot(y, z)) 
True 
%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready() 
5 loops, best of 5: 19.3 ms per loop 
%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready() 
5 loops, best of 5: 3.25 ms per loop 

即使复制一个分片的 Array,也会产生具有输入分片的结果:

w_copy = jnp.copy(w)
jax.debug.visualize_array_sharding(w_copy) 
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘ 

因此,当我们使用 jax.device_put 明确分片数据并对该数据应用函数时,编译器会尝试并行化计算并决定输出分片。这种对分片数据的策略是JAX 遵循显式设备放置策略的泛化

当明确分片不一致时,JAX 会报错

但是如果计算的两个参数在不同的设备组上明确放置,或者设备顺序不兼容,会发生错误:

import textwrap
from termcolor import colored
def print_exception(e):
  name = colored(f'{type(e).__name__}', 'red')
  print(textwrap.fill(f'{name}: {str(e)}')) 
sharding1 = PositionalSharding(jax.devices()[:4])
sharding2 = PositionalSharding(jax.devices()[4:])
y = jax.device_put(x, sharding1.reshape(2, 2))
z = jax.device_put(x, sharding2.reshape(2, 2))
try: y + z
except ValueError as e: print_exception(e) 
ValueError: Devices of all `Array` inputs and outputs should
be the same. Got array device ids [0, 1, 2, 3] on platform TPU and
another array's device ids [4, 5, 6, 7] on platform TPU 
devices = jax.devices()
permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]
sharding1 = PositionalSharding(devices)
sharding2 = PositionalSharding(permuted_devices)
y = jax.device_put(x, sharding1.reshape(4, 2))
z = jax.device_put(x, sharding2.reshape(4, 2))
try: y + z
except ValueError as e: print_exception(e) 
ValueError: Devices of all `Array` inputs and outputs should
be the same. Got array device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform
TPU and another array's device ids [0, 1, 2, 3, 6, 7, 4, 5] on
platform TPU 

我们说通过 jax.device_put 明确放置或分片的数组已经锁定在它们的设备上,因此不会自动移动。请查看 设备放置常见问题解答 获取更多信息。

当数组没有使用 jax.device_put 明确放置或分片时,它们会放置在默认设备上并未锁定。与已锁定数组不同,未锁定数组可以自动移动和重新分片:也就是说,未锁定数组可以作为计算的参数,即使其他参数明确放置在不同的设备上。

例如,jnp.zerosjnp.arangejnp.array 的输出都是未锁定的:

y = jax.device_put(x, sharding1.reshape(4, 2))
y + jnp.ones_like(y)
y + jnp.arange(y.size).reshape(y.shape)
print('no error!') 
no error! 

限制在 jit 代码中的中间片段

虽然编译器将尝试决定函数的中间值和输出应如何分片,但我们还可以使用 jax.lax.with_sharding_constraint 来给它提供提示。使用 jax.lax.with_sharding_constraint 类似于 jax.device_put,不同之处在于我们在分阶段函数(即 jit 装饰的函数)内部使用它:

sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) 
x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = jax.device_put(x, sharding.reshape(4, 2)) 
@jax.jit
def f(x):
  x = x + 1
  y = jax.lax.with_sharding_constraint(x, sharding.reshape(2, 4))
  return y 
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y) 
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
┌───────┬───────┬───────┬───────┐
│       │       │       │       │
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │
│       │       │       │       │
│       │       │       │       │
├───────┼───────┼───────┼───────┤
│       │       │       │       │
│ TPU 6 │ TPU 7 │ TPU 4 │ TPU 5 │
│       │       │       │       │
│       │       │       │       │
└───────┴───────┴───────┴───────┘ 
@jax.jit
def f(x):
  x = x + 1
  y = jax.lax.with_sharding_constraint(x, sharding.replicate())
  return y 
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y) 
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
┌───────────────────────┐
│                       │
│                       │
│                       │
│                       │
│  TPU 0,1,2,3,4,5,6,7  │
│                       │
│                       │
│                       │
│                       │
└───────────────────────┘ 

通过添加 with_sharding_constraint,我们限制了输出的分片。除了尊重特定中间变量的注释外,编译器还会使用注释来决定其他值的分片。

经常的好做法是注释计算的输出,例如根据值最终如何被使用来注释它们。


JAX 中文文档(六)(5)https://developer.aliyun.com/article/1559685

相关文章
|
4月前
|
机器学习/深度学习 存储 移动开发
JAX 中文文档(八)(1)
JAX 中文文档(八)
29 1
|
4月前
|
并行计算 API C++
JAX 中文文档(九)(4)
JAX 中文文档(九)
39 1
|
4月前
|
机器学习/深度学习 PyTorch API
JAX 中文文档(六)(1)
JAX 中文文档(六)
36 0
JAX 中文文档(六)(1)
|
4月前
|
编译器 异构计算 索引
JAX 中文文档(五)(4)
JAX 中文文档(五)
65 0
|
4月前
|
API Python
JAX 中文文档(八)(3)
JAX 中文文档(八)
33 0
|
4月前
|
测试技术 API Python
JAX 中文文档(八)(4)
JAX 中文文档(八)
31 0
|
4月前
|
存储 安全 API
JAX 中文文档(十)(2)
JAX 中文文档(十)
35 0
|
4月前
|
存储 Python
JAX 中文文档(十)(3)
JAX 中文文档(十)
29 0
|
4月前
|
存储 并行计算 开发工具
JAX 中文文档(十)(1)
JAX 中文文档(十)
46 0
|
4月前
|
存储 缓存 测试技术
JAX 中文文档(三)(5)
JAX 中文文档(三)
51 0