JAX 中文文档(六)(3)https://developer.aliyun.com/article/1559683
计算遵循数据分片并自动并行化
使用分片输入数据,编译器可以给我们并行计算。特别是,用 jax.jit
装饰的函数可以在分片数组上操作,而无需将数据复制到单个设备上。相反,计算遵循分片:基于输入数据的分片,编译器决定中间结果和输出值的分片,并并行评估它们,必要时甚至插入通信操作。
例如,最简单的计算是逐元素的:
from jax.experimental import mesh_utils from jax.sharding import PositionalSharding sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
x = jax.device_put(x, sharding.reshape(4, 2)) print('input sharding:') jax.debug.visualize_array_sharding(x) y = jnp.sin(x) print('output sharding:') jax.debug.visualize_array_sharding(y)
input sharding: ┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘ output sharding: ┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
这里对于逐元素操作 jnp.sin
,编译器选择了输出分片与输入相同。此外,编译器自动并行化计算,因此每个设备都可以并行计算其输出片段。
换句话说,即使我们将 jnp.sin
的计算写成单台机器执行,编译器也会为我们拆分计算并在多个设备上执行。
我们不仅可以对逐元素操作执行相同操作。考虑使用分片输入的矩阵乘法:
y = jax.device_put(x, sharding.reshape(4, 2).replicate(1)) z = jax.device_put(x, sharding.reshape(4, 2).replicate(0)) print('lhs sharding:') jax.debug.visualize_array_sharding(y) print('rhs sharding:') jax.debug.visualize_array_sharding(z) w = jnp.dot(y, z) print('out sharding:') jax.debug.visualize_array_sharding(w)
lhs sharding: ┌───────────────────────┐ │ TPU 0,1 │ ├───────────────────────┤ │ TPU 2,3 │ ├───────────────────────┤ │ TPU 6,7 │ ├───────────────────────┤ │ TPU 4,5 │ └───────────────────────┘ rhs sharding: ┌───────────┬───────────┐ │ │ │ │ │ │ │ │ │ │ │ │ │TPU 0,2,4,6│TPU 1,3,5,7│ │ │ │ │ │ │ │ │ │ │ │ │ └───────────┴───────────┘ out sharding: ┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
这里编译器选择了输出分片,以便最大化并行计算:无需通信,每个设备已经具有计算其输出分片所需的输入分片。
我们如何确保它实际上是并行运行的?我们可以进行简单的时间实验:
x_single = jax.device_put(x, jax.devices()[0]) jax.debug.visualize_array_sharding(x_single)
┌───────────────────────┐ │ │ │ │ │ │ │ │ │ TPU 0 │ │ │ │ │ │ │ │ │ └───────────────────────┘
np.allclose(jnp.dot(x_single, x_single), jnp.dot(y, z))
True
%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready()
5 loops, best of 5: 19.3 ms per loop
%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready()
5 loops, best of 5: 3.25 ms per loop
即使复制一个分片的 Array
,也会产生具有输入分片的结果:
w_copy = jnp.copy(w) jax.debug.visualize_array_sharding(w_copy)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘
因此,当我们使用 jax.device_put
明确分片数据并对该数据应用函数时,编译器会尝试并行化计算并决定输出分片。这种对分片数据的策略是JAX 遵循显式设备放置策略的泛化。
当明确分片不一致时,JAX 会报错
但是如果计算的两个参数在不同的设备组上明确放置,或者设备顺序不兼容,会发生错误:
import textwrap from termcolor import colored def print_exception(e): name = colored(f'{type(e).__name__}', 'red') print(textwrap.fill(f'{name}: {str(e)}'))
sharding1 = PositionalSharding(jax.devices()[:4]) sharding2 = PositionalSharding(jax.devices()[4:]) y = jax.device_put(x, sharding1.reshape(2, 2)) z = jax.device_put(x, sharding2.reshape(2, 2)) try: y + z except ValueError as e: print_exception(e)
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0, 1, 2, 3] on platform TPU and another array's device ids [4, 5, 6, 7] on platform TPU
devices = jax.devices() permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]] sharding1 = PositionalSharding(devices) sharding2 = PositionalSharding(permuted_devices) y = jax.device_put(x, sharding1.reshape(4, 2)) z = jax.device_put(x, sharding2.reshape(4, 2)) try: y + z except ValueError as e: print_exception(e)
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform TPU and another array's device ids [0, 1, 2, 3, 6, 7, 4, 5] on platform TPU
我们说通过 jax.device_put
明确放置或分片的数组已经锁定在它们的设备上,因此不会自动移动。请查看 设备放置常见问题解答 获取更多信息。
当数组没有使用 jax.device_put
明确放置或分片时,它们会放置在默认设备上并未锁定。与已锁定数组不同,未锁定数组可以自动移动和重新分片:也就是说,未锁定数组可以作为计算的参数,即使其他参数明确放置在不同的设备上。
例如,jnp.zeros
、jnp.arange
和 jnp.array
的输出都是未锁定的:
y = jax.device_put(x, sharding1.reshape(4, 2)) y + jnp.ones_like(y) y + jnp.arange(y.size).reshape(y.shape) print('no error!')
no error!
限制在 jit
代码中的中间片段
虽然编译器将尝试决定函数的中间值和输出应如何分片,但我们还可以使用 jax.lax.with_sharding_constraint
来给它提供提示。使用 jax.lax.with_sharding_constraint
类似于 jax.device_put
,不同之处在于我们在分阶段函数(即 jit
装饰的函数)内部使用它:
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
x = jax.random.normal(jax.random.key(0), (8192, 8192)) x = jax.device_put(x, sharding.reshape(4, 2))
@jax.jit def f(x): x = x + 1 y = jax.lax.with_sharding_constraint(x, sharding.reshape(2, 4)) return y
jax.debug.visualize_array_sharding(x) y = f(x) jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘ ┌───────┬───────┬───────┬───────┐ │ │ │ │ │ │ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ │ │ │ │ │ │ │ │ │ │ ├───────┼───────┼───────┼───────┤ │ │ │ │ │ │ TPU 6 │ TPU 7 │ TPU 4 │ TPU 5 │ │ │ │ │ │ │ │ │ │ │ └───────┴───────┴───────┴───────┘
@jax.jit def f(x): x = x + 1 y = jax.lax.with_sharding_constraint(x, sharding.replicate()) return y
jax.debug.visualize_array_sharding(x) y = f(x) jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐ │ TPU 0 │ TPU 1 │ ├──────────┼──────────┤ │ TPU 2 │ TPU 3 │ ├──────────┼──────────┤ │ TPU 6 │ TPU 7 │ ├──────────┼──────────┤ │ TPU 4 │ TPU 5 │ └──────────┴──────────┘ ┌───────────────────────┐ │ │ │ │ │ │ │ │ │ TPU 0,1,2,3,4,5,6,7 │ │ │ │ │ │ │ │ │ └───────────────────────┘
通过添加 with_sharding_constraint
,我们限制了输出的分片。除了尊重特定中间变量的注释外,编译器还会使用注释来决定其他值的分片。
经常的好做法是注释计算的输出,例如根据值最终如何被使用来注释它们。
JAX 中文文档(六)(5)https://developer.aliyun.com/article/1559685