JAX 中文文档(六)(4)https://developer.aliyun.com/article/1559684
示例:神经网络
⚠️ 警告:以下内容旨在简单演示使用 jax.Array
进行自动分片传播,但可能不反映实际示例的最佳实践。 例如,实际示例可能需要更多使用 with_sharding_constraint
。
我们可以利用 jax.device_put
和 jax.jit
的计算跟随分片特性来并行化神经网络中的计算。以下是基于这种基本神经网络的一些简单示例:
import jax import jax.numpy as jnp
def predict(params, inputs): for W, b in params: outputs = jnp.dot(inputs, W) + b inputs = jnp.maximum(outputs, 0) return outputs def loss(params, batch): inputs, targets = batch predictions = predict(params, inputs) return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
loss_jit = jax.jit(loss) gradfun = jax.jit(jax.grad(loss))
def init_layer(key, n_in, n_out): k1, k2 = jax.random.split(key) W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in) b = jax.random.normal(k2, (n_out,)) return W, b def init_model(key, layer_sizes, batch_size): key, *keys = jax.random.split(key, len(layer_sizes)) params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:])) key, *keys = jax.random.split(key, 3) inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0])) targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1])) return params, (inputs, targets) layer_sizes = [784, 8192, 8192, 8192, 10] batch_size = 8192 params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)
8 路批数据并行
sharding = PositionalSharding(jax.devices()).reshape(8, 1)
batch = jax.device_put(batch, sharding) params = jax.device_put(params, sharding.replicate())
loss_jit(params, batch)
Array(23.469475, dtype=float32)
step_size = 1e-5 for _ in range(30): grads = gradfun(params, batch) params = [(W - step_size * dW, b - step_size * db) for (W, b), (dW, db) in zip(params, grads)] print(loss_jit(params, batch))
10.760101
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()
5 loops, best of 5: 26.3 ms per loop
batch_single = jax.device_put(batch, jax.devices()[0]) params_single = jax.device_put(params, jax.devices()[0])
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()
5 loops, best of 5: 122 ms per loop
4 路批数据并行和 2 路模型张量并行
sharding = sharding.reshape(4, 2)
batch = jax.device_put(batch, sharding.replicate(1)) jax.debug.visualize_array_sharding(batch[0]) jax.debug.visualize_array_sharding(batch[1])
┌───────┐ │TPU 0,1│ ├───────┤ │TPU 2,3│ ├───────┤ │TPU 4,5│ ├───────┤ │TPU 6,7│ └───────┘ ┌───────┐ │TPU 0,1│ ├───────┤ │TPU 2,3│ ├───────┤ │TPU 4,5│ ├───────┤ │TPU 6,7│ └───────┘
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params W1 = jax.device_put(W1, sharding.replicate()) b1 = jax.device_put(b1, sharding.replicate()) W2 = jax.device_put(W2, sharding.replicate(0)) b2 = jax.device_put(b2, sharding.replicate(0)) W3 = jax.device_put(W3, sharding.replicate(0).T) b3 = jax.device_put(b3, sharding.replicate()) W4 = jax.device_put(W4, sharding.replicate()) b4 = jax.device_put(b4, sharding.replicate()) params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)
jax.debug.visualize_array_sharding(W2)
┌───────────┬───────────┐ │ │ │ │ │ │ │ │ │ │ │ │ │TPU 0,2,4,6│TPU 1,3,5,7│ │ │ │ │ │ │ │ │ │ │ │ │ └───────────┴───────────┘
jax.debug.visualize_array_sharding(W3)
┌───────────────────────┐ │ │ │ TPU 0,2,4,6 │ │ │ │ │ ├───────────────────────┤ │ │ │ TPU 1,3,5,7 │ │ │ │ │ └───────────────────────┘
print(loss_jit(params, batch))
10.760103
step_size = 1e-5 for _ in range(30): grads = gradfun(params, batch) params = [(W - step_size * dW, b - step_size * db) for (W, b), (dW, db) in zip(params, grads)]
print(loss_jit(params, batch))
10.752466
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params jax.debug.visualize_array_sharding(W2) jax.debug.visualize_array_sharding(W3)
┌───────────┬───────────┐ │ │ │ │ │ │ │ │ │ │ │ │ │TPU 0,2,4,6│TPU 1,3,5,7│ │ │ │ │ │ │ │ │ │ │ │ │ └───────────┴───────────┘ ┌───────────────────────┐ │ │ │ TPU 0,2,4,6 │ │ │ │ │ ├───────────────────────┤ │ │ │ TPU 1,3,5,7 │ │ │ │ │ └───────────────────────┘
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()
10 loops, best of 10: 30.5 ms per loop
锐利的部分
生成随机数
JAX 自带一个功能强大且确定性的 随机数生成器。它支持 jax.random
模块 中的各种采样函数,如 jax.random.uniform
。
JAX 的随机数是由基于计数器的 PRNG 生成的,因此原则上,随机数生成应该是对计数器值的纯映射。原则上,纯映射是一个可以轻松分片的操作。它不应需要跨设备通信,也不应需要设备间的冗余计算。
然而,由于历史原因,现有的稳定 RNG 实现并非自动可分片。
考虑以下示例,其中一个函数绘制随机均匀数并将其逐元素添加到输入中:
@jax.jit def f(key, x): numbers = jax.random.uniform(key, x.shape) return x + numbers key = jax.random.key(42) x_sharding = jax.sharding.PositionalSharding(jax.devices()) x = jax.device_put(jnp.arange(24), x_sharding)
在分区输入上,函数 f
生成的输出也是分区的:
jax.debug.visualize_array_sharding(f(key, x))
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐ │ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │ └───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
但是,如果我们检查 f
在这个分区输入上的编译计算,我们会发现它确实涉及一些通信:
f_exe = f.lower(key, x).compile() print('Communicating?', 'collective-permute' in f_exe.as_text())
Communicating? True
解决这个问题的一种方法是使用实验性升级标志 jax_threefry_partitionable
配置 JAX。启用该标志后,编译计算中的“集体排列”操作现在已经消失:
jax.config.update('jax_threefry_partitionable', True) f_exe = f.lower(key, x).compile() print('Communicating?', 'collective-permute' in f_exe.as_text())
Communicating? False
输出仍然是分区的:
jax.debug.visualize_array_sharding(f(key, x))
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐ │ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │ └───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
然而,jax_threefry_partitionable
选项的一个注意事项是,即使是由相同随机密钥生成的,使用该标志设置后生成的随机值可能与未设置标志时不同:
jax.config.update('jax_threefry_partitionable', False) print('Stable:') print(f(key, x)) print() jax.config.update('jax_threefry_partitionable', True) print('Partitionable:') print(f(key, x))
Stable: [ 0.72503686 1.8532515 2.983416 3.083253 4.0332246 5.4782867 6.1720605 7.6900277 8.602836 9.810046 10.861367 11.907651 12.330483 13.456195 14.808557 15.960099 16.067581 17.739723 18.335474 19.46401 20.390276 21.116539 22.858128 23.223194 ] Partitionable: [ 0.48870957 1.6797972 2.6162715 3.561016 4.4506445 5.585866 6.0748096 7.775133 8.698959 9.818634 10.350306 11.87282 12.925881 13.86013 14.477554 15.818481 16.711355 17.586697 18.073738 19.777622 20.404566 21.119123 22.026257 23.63918 ]
在 jax_threefry_partitionable
模式下,JAX 的 PRNG 保持确定性,但其实现是新的(并且正在开发中)。为给定密钥生成的随机值在特定的 JAX 版本(或 main
分支上的特定提交)中将保持相同,但在不同版本之间可能会有所变化。
= jax.device_put(jnp.arange(24), x_sharding)
在分区输入上,函数 `f` 生成的输出也是分区的: ```py jax.debug.visualize_array_sharding(f(key, x))
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐ │ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │ └───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
但是,如果我们检查 f
在这个分区输入上的编译计算,我们会发现它确实涉及一些通信:
f_exe = f.lower(key, x).compile() print('Communicating?', 'collective-permute' in f_exe.as_text())
Communicating? True
解决这个问题的一种方法是使用实验性升级标志 jax_threefry_partitionable
配置 JAX。启用该标志后,编译计算中的“集体排列”操作现在已经消失:
jax.config.update('jax_threefry_partitionable', True) f_exe = f.lower(key, x).compile() print('Communicating?', 'collective-permute' in f_exe.as_text())
Communicating? False
输出仍然是分区的:
jax.debug.visualize_array_sharding(f(key, x))
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐ │ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │ └───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
然而,jax_threefry_partitionable
选项的一个注意事项是,即使是由相同随机密钥生成的,使用该标志设置后生成的随机值可能与未设置标志时不同:
jax.config.update('jax_threefry_partitionable', False) print('Stable:') print(f(key, x)) print() jax.config.update('jax_threefry_partitionable', True) print('Partitionable:') print(f(key, x))
Stable: [ 0.72503686 1.8532515 2.983416 3.083253 4.0332246 5.4782867 6.1720605 7.6900277 8.602836 9.810046 10.861367 11.907651 12.330483 13.456195 14.808557 15.960099 16.067581 17.739723 18.335474 19.46401 20.390276 21.116539 22.858128 23.223194 ] Partitionable: [ 0.48870957 1.6797972 2.6162715 3.561016 4.4506445 5.585866 6.0748096 7.775133 8.698959 9.818634 10.350306 11.87282 12.925881 13.86013 14.477554 15.818481 16.711355 17.586697 18.073738 19.777622 20.404566 21.119123 22.026257 23.63918 ]
在 jax_threefry_partitionable
模式下,JAX 的 PRNG 保持确定性,但其实现是新的(并且正在开发中)。为给定密钥生成的随机值在特定的 JAX 版本(或 main
分支上的特定提交)中将保持相同,但在不同版本之间可能会有所变化。