JAX 中文文档(六)(5)

简介: JAX 中文文档(六)

JAX 中文文档(六)(4)https://developer.aliyun.com/article/1559684


示例:神经网络

⚠️ 警告:以下内容旨在简单演示使用 jax.Array 进行自动分片传播,但可能不反映实际示例的最佳实践。 例如,实际示例可能需要更多使用 with_sharding_constraint

我们可以利用 jax.device_putjax.jit 的计算跟随分片特性来并行化神经网络中的计算。以下是基于这种基本神经网络的一些简单示例:

import jax
import jax.numpy as jnp 
def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.maximum(outputs, 0)
  return outputs
def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1)) 
loss_jit = jax.jit(loss)
gradfun = jax.jit(jax.grad(loss)) 
def init_layer(key, n_in, n_out):
    k1, k2 = jax.random.split(key)
    W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
    b = jax.random.normal(k2, (n_out,))
    return W, b
def init_model(key, layer_sizes, batch_size):
    key, *keys = jax.random.split(key, len(layer_sizes))
    params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
    key, *keys = jax.random.split(key, 3)
    inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
    targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))
    return params, (inputs, targets)
layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192
params, batch = init_model(jax.random.key(0), layer_sizes, batch_size) 

8 路批数据并行

sharding = PositionalSharding(jax.devices()).reshape(8, 1) 
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, sharding.replicate()) 
loss_jit(params, batch) 
Array(23.469475, dtype=float32) 
step_size = 1e-5
for _ in range(30):
  grads = gradfun(params, batch)
  params = [(W - step_size * dW, b - step_size * db)
            for (W, b), (dW, db) in zip(params, grads)]
print(loss_jit(params, batch)) 
10.760101 
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready() 
5 loops, best of 5: 26.3 ms per loop 
batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0]) 
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready() 
5 loops, best of 5: 122 ms per loop 

4 路批数据并行和 2 路模型张量并行

sharding = sharding.reshape(4, 2) 
batch = jax.device_put(batch, sharding.replicate(1))
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1]) 
┌───────┐
│TPU 0,1│
├───────┤
│TPU 2,3│
├───────┤
│TPU 4,5│
├───────┤
│TPU 6,7│
└───────┘
┌───────┐
│TPU 0,1│
├───────┤
│TPU 2,3│
├───────┤
│TPU 4,5│
├───────┤
│TPU 6,7│
└───────┘ 
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
W1 = jax.device_put(W1, sharding.replicate())
b1 = jax.device_put(b1, sharding.replicate())
W2 = jax.device_put(W2, sharding.replicate(0))
b2 = jax.device_put(b2, sharding.replicate(0))
W3 = jax.device_put(W3, sharding.replicate(0).T)
b3 = jax.device_put(b3, sharding.replicate())
W4 = jax.device_put(W4, sharding.replicate())
b4 = jax.device_put(b4, sharding.replicate())
params = (W1, b1), (W2, b2), (W3, b3), (W4, b4) 
jax.debug.visualize_array_sharding(W2) 
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘ 
jax.debug.visualize_array_sharding(W3) 
┌───────────────────────┐
│                       │
│      TPU 0,2,4,6      │
│                       │
│                       │
├───────────────────────┤
│                       │
│      TPU 1,3,5,7      │
│                       │
│                       │
└───────────────────────┘ 
print(loss_jit(params, batch)) 
10.760103 
step_size = 1e-5
for _ in range(30):
    grads = gradfun(params, batch)
    params = [(W - step_size * dW, b - step_size * db)
              for (W, b), (dW, db) in zip(params, grads)] 
print(loss_jit(params, batch)) 
10.752466 
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3) 
┌───────────┬───────────┐
│           │           │
│           │           │
│           │           │
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
│           │           │
│           │           │
│           │           │
└───────────┴───────────┘
┌───────────────────────┐
│                       │
│      TPU 0,2,4,6      │
│                       │
│                       │
├───────────────────────┤
│                       │
│      TPU 1,3,5,7      │
│                       │
│                       │
└───────────────────────┘ 
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready() 
10 loops, best of 10: 30.5 ms per loop 

锐利的部分

生成随机数

JAX 自带一个功能强大且确定性的 随机数生成器。它支持 jax.random 模块 中的各种采样函数,如 jax.random.uniform

JAX 的随机数是由基于计数器的 PRNG 生成的,因此原则上,随机数生成应该是对计数器值的纯映射。原则上,纯映射是一个可以轻松分片的操作。它不应需要跨设备通信,也不应需要设备间的冗余计算。

然而,由于历史原因,现有的稳定 RNG 实现并非自动可分片。

考虑以下示例,其中一个函数绘制随机均匀数并将其逐元素添加到输入中:

@jax.jit
def f(key, x):
  numbers = jax.random.uniform(key, x.shape)
  return x + numbers
key = jax.random.key(42)
x_sharding = jax.sharding.PositionalSharding(jax.devices())
x = jax.device_put(jnp.arange(24), x_sharding) 

在分区输入上,函数 f 生成的输出也是分区的:

jax.debug.visualize_array_sharding(f(key, x)) 
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘ 

但是,如果我们检查 f 在这个分区输入上的编译计算,我们会发现它确实涉及一些通信:

f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text()) 
Communicating? True 

解决这个问题的一种方法是使用实验性升级标志 jax_threefry_partitionable 配置 JAX。启用该标志后,编译计算中的“集体排列”操作现在已经消失:

jax.config.update('jax_threefry_partitionable', True)
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text()) 
Communicating? False 

输出仍然是分区的:

jax.debug.visualize_array_sharding(f(key, x)) 
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘ 

然而,jax_threefry_partitionable 选项的一个注意事项是,即使是由相同随机密钥生成的,使用该标志设置后生成的随机值可能与未设置标志时不同

jax.config.update('jax_threefry_partitionable', False)
print('Stable:')
print(f(key, x))
print()
jax.config.update('jax_threefry_partitionable', True)
print('Partitionable:')
print(f(key, x)) 
Stable:
[ 0.72503686  1.8532515   2.983416    3.083253    4.0332246   5.4782867
  6.1720605   7.6900277   8.602836    9.810046   10.861367   11.907651
 12.330483   13.456195   14.808557   15.960099   16.067581   17.739723
 18.335474   19.46401    20.390276   21.116539   22.858128   23.223194  ]
Partitionable:
[ 0.48870957  1.6797972   2.6162715   3.561016    4.4506445   5.585866
  6.0748096   7.775133    8.698959    9.818634   10.350306   11.87282
 12.925881   13.86013    14.477554   15.818481   16.711355   17.586697
 18.073738   19.777622   20.404566   21.119123   22.026257   23.63918   ] 

jax_threefry_partitionable 模式下,JAX 的 PRNG 保持确定性,但其实现是新的(并且正在开发中)。为给定密钥生成的随机值在特定的 JAX 版本(或 main 分支上的特定提交)中将保持相同,但在不同版本之间可能会有所变化。

= jax.device_put(jnp.arange(24), x_sharding)

在分区输入上,函数 `f` 生成的输出也是分区的:
```py
jax.debug.visualize_array_sharding(f(key, x)) 
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘ 

但是,如果我们检查 f 在这个分区输入上的编译计算,我们会发现它确实涉及一些通信:

f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text()) 
Communicating? True 

解决这个问题的一种方法是使用实验性升级标志 jax_threefry_partitionable 配置 JAX。启用该标志后,编译计算中的“集体排列”操作现在已经消失:

jax.config.update('jax_threefry_partitionable', True)
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text()) 
Communicating? False 

输出仍然是分区的:

jax.debug.visualize_array_sharding(f(key, x)) 
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘ 

然而,jax_threefry_partitionable 选项的一个注意事项是,即使是由相同随机密钥生成的,使用该标志设置后生成的随机值可能与未设置标志时不同

jax.config.update('jax_threefry_partitionable', False)
print('Stable:')
print(f(key, x))
print()
jax.config.update('jax_threefry_partitionable', True)
print('Partitionable:')
print(f(key, x)) 
Stable:
[ 0.72503686  1.8532515   2.983416    3.083253    4.0332246   5.4782867
  6.1720605   7.6900277   8.602836    9.810046   10.861367   11.907651
 12.330483   13.456195   14.808557   15.960099   16.067581   17.739723
 18.335474   19.46401    20.390276   21.116539   22.858128   23.223194  ]
Partitionable:
[ 0.48870957  1.6797972   2.6162715   3.561016    4.4506445   5.585866
  6.0748096   7.775133    8.698959    9.818634   10.350306   11.87282
 12.925881   13.86013    14.477554   15.818481   16.711355   17.586697
 18.073738   19.777622   20.404566   21.119123   22.026257   23.63918   ] 

jax_threefry_partitionable 模式下,JAX 的 PRNG 保持确定性,但其实现是新的(并且正在开发中)。为给定密钥生成的随机值在特定的 JAX 版本(或 main 分支上的特定提交)中将保持相同,但在不同版本之间可能会有所变化。

相关文章
|
9天前
|
并行计算 API C++
JAX 中文文档(九)(4)
JAX 中文文档(九)
12 1
|
9天前
|
并行计算 API 异构计算
JAX 中文文档(六)(2)
JAX 中文文档(六)
13 1
|
9天前
|
机器学习/深度学习 缓存 编译器
JAX 中文文档(二)(1)
JAX 中文文档(二)
15 0
|
9天前
|
存储 机器学习/深度学习 并行计算
JAX 中文文档(二)(5)
JAX 中文文档(二)
10 0
|
9天前
|
机器学习/深度学习 存储 并行计算
JAX 中文文档(七)(3)
JAX 中文文档(七)
10 0
|
9天前
|
机器学习/深度学习 异构计算 AI芯片
JAX 中文文档(七)(4)
JAX 中文文档(七)
7 0
|
9天前
|
机器学习/深度学习 缓存 API
JAX 中文文档(一)(4)
JAX 中文文档(一)
12 0
|
9天前
|
存储 并行计算 数据可视化
JAX 中文文档(六)(3)
JAX 中文文档(六)
11 0
|
9天前
|
存储 编译器 芯片
JAX 中文文档(五)(5)
JAX 中文文档(五)
8 0
|
9天前
|
API 索引 Python
JAX 中文文档(三)(4)
JAX 中文文档(三)
6 0