Toy examples
How might we use shard_map and collective communication in practice? These examples, while simple, give some ideas.
Matrix multiplication
Parallelizing matrix multiplication is central to scaling up deep learning models, both for training and for inference. When jax.jit automatically parallelizes matrix multiplication, it can use one of several different strategies, depending on matrix sizes, hardware details, and other factors. How might we write some of those parallelized routines more explicitly using shard_map? And how can we optimize them to get better compute/communication overlap, and thus improved FLOP utilization?
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
mesh = Mesh(jax.devices()[:4], ('i',))

def device_put(x, pspec):
  return jax.device_put(x, NamedSharding(mesh, pspec))
Example 1: all-gather on one side
Consider a matmul where we shard the left-hand side argument (think: parameters) over its leading (non-contracting) dimension:
lhs_spec = P('i', None)
lhs = device_put(jax.random.normal(jax.random.key(0), (8, 8)), lhs_spec)
and we shard the right-hand side argument (think: activations) over its contracting dimension, with the output sharded similarly:
rhs_spec = P('i', None)
rhs = device_put(jax.random.normal(jax.random.key(1), (8, 4)), rhs_spec)
To perform this matmul, we can first all-gather the right-hand side and then perform the matmul locally against the sharded left-hand side:
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather(lhs_block, rhs_block):
  rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)
  return lhs_block @ rhs
out = matmul_allgather(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
That's great, but we're not getting any compute/communication overlap here: before we can start the matmul, we need the all_gather to complete. Here's a profile of the same code, but on larger example shapes ((8192, 8192) for lhs and (8192, 1024) for rhs); the profile image is omitted here.
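A minimal sketch (not part of the original article) of how such a profile might be collected with jax.profiler, assuming operands of those larger shapes on the same 4-device mesh:

lhs_big = device_put(jax.random.normal(jax.random.key(0), (8192, 8192)), lhs_spec)
rhs_big = device_put(jax.random.normal(jax.random.key(1), (8192, 1024)), rhs_spec)

# write a trace viewable in TensorBoard/Perfetto; the all-gather should appear
# to finish before the local matmul starts
with jax.profiler.trace('/tmp/jax-trace'):
  matmul_allgather(lhs_big, rhs_big).block_until_ready()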
If, instead of calling all_gather, we basically inline our implementation of all_gather above in terms of ppermute, we get compute/communication overlap by interleaving the steps of the gather permutation with local matmuls:
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather_overlapped(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift = partial(jax.lax.ppermute, axis_name='i',
                  perm=[(i, (i + 1) % size) for i in range(size)])

  B = lhs_block.shape[1] // size
  lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)

  out_block = lhs_blocks(idx) @ rhs_block
  for i in range(1, size):
    rhs_block = shift(rhs_block)
    out_block += lhs_blocks((idx - i) % size) @ rhs_block
  return out_block
out = matmul_allgather_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
This implementation allows overlap between communication and computation, and it also avoids gathering a large intermediate onto each device. But on TPU it uses only half the interconnect bandwidth by permuting in only one direction along the ring. To permute bidirectionally, we just split the blocks in half and send each half in each direction:
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift_up = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i + 1) % size) for i in range(size)])
  shift_dn = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i - 1) % size) for i in range(size)])

  B = lhs_block.shape[1] // size // 2  # half-size blocks
  lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 1)

  rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)
  out_block  = lhs_blocks(idx, 0) @ rhs_block_lo
  out_block += lhs_blocks(idx, 1) @ rhs_block_hi
  for i in range(1, size):
    rhs_block_lo = shift_up(rhs_block_lo)
    rhs_block_hi = shift_dn(rhs_block_hi)
    out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo
    out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi
  return out_block
out = matmul_allgather_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
In practice, to reduce compile times we would probably roll these loops into jax.lax.fori_loop, as sketched below. We might also have additional axes of parallelism involved.
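A rough, hedged sketch (not from the original article) of rolling the unidirectional overlapped matmul into jax.lax.fori_loop might look like this:

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather_overlapped_fori(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift = partial(jax.lax.ppermute, axis_name='i',
                  perm=[(i, (i + 1) % size) for i in range(size)])
  B = lhs_block.shape[1] // size
  lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)

  def body(i, carry):
    # shift the rhs block one step around the ring, then accumulate one local matmul
    out_block, rhs_block = carry
    rhs_block = shift(rhs_block)
    out_block += lhs_blocks((idx - i) % size) @ rhs_block
    return out_block, rhs_block

  out_block = lhs_blocks(idx) @ rhs_block
  out_block, _ = jax.lax.fori_loop(1, size, body, (out_block, rhs_block))
  return out_block

The trade-off is that the loop body is traced once rather than unrolled size - 1 times, which matters more as the axis size grows.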
Example 2: psum_scatter the result
Another sharding we might start from has both lhs and rhs sharded along their contracting dimensions, with the output again sharded like rhs:
lhs_spec = P(None, 'i')
lhs = device_put(lhs, lhs_spec)

rhs_spec = P('i', None)
rhs = device_put(rhs, rhs_spec)
Here we can use a reduce-scatter to perform the contraction sum over the shards:
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter(lhs_block, rhs_block):
  out_summand = lhs_block @ rhs_block
  return jax.lax.psum_scatter(out_summand, 'i', tiled=True)

out = matmul_psumscatter(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
But the scattering communication must wait for the entire local matmul to finish before it can start. To get communication/computation overlap, we can inline an implementation of psum_scatter in terms of ppermute, then interleave the communication steps with the local matmuls:
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter_overlapped(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift = partial(jax.lax.ppermute, axis_name='i',
                  perm=[(i, (i - 1) % size) for i in range(size)])
  lhs_block = lhs_block.reshape(size, -1, lhs_block.shape[1])  # split 1st axis

  out_summand = lhs_block[(idx + 1) % size] @ rhs_block
  for i in range(1, size):
    out_summand = shift(out_summand)
    out_summand += lhs_block[(idx + i + 1) % size] @ rhs_block
  return out_summand
out = matmul_psumscatter_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
As in the previous example, to fully utilize the interconnects on TPU, we'd run a bidirectional version:
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift_up = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i + 1) % size) for i in range(size)])
  shift_dn = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i - 1) % size) for i in range(size)])

  B = lhs_block.shape[0] // size // 2  # half-size blocks
  lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 0)

  out_summand_lo = lhs_blocks((idx - 1) % size, 0) @ rhs_block
  out_summand_hi = lhs_blocks((idx + 1) % size, 1) @ rhs_block
  for i in range(1, size):
    out_summand_lo = shift_up(out_summand_lo)
    out_summand_hi = shift_dn(out_summand_hi)
    out_summand_lo += lhs_blocks((idx - i - 1) % size, 0) @ rhs_block
    out_summand_hi += lhs_blocks((idx + i + 1) % size, 1) @ rhs_block
  return jnp.concatenate([out_summand_lo, out_summand_hi])
out = matmul_psumscatter_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
True
Neural networks
We can use shard_map to parallelize computation in neural networks, either by itself or in combination with the automatic partitioning in jax.jit. This section has a few examples based on the following toy neural network and random data:
import jax
import jax.numpy as jnp

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jax.nn.relu(outputs)
  return outputs

def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
def init_layer(key, n_in, n_out):
  k1, k2 = jax.random.split(key)
  W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
  b = jax.random.normal(k2, (n_out,))
  return W, b

def init(key, layer_sizes, batch_size):
  key, *keys = jax.random.split(key, len(layer_sizes))
  params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))

  key, *keys = jax.random.split(key, 3)
  inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
  targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))

  return params, (inputs, targets)
layer_sizes = [784, 128, 128, 128, 128, 128, 8]
batch_size = 32

params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)
Compare these examples with the purely automatic partitioning examples in the "Distributed arrays and automatic partitioning" doc. While in those automatic partitioning examples we don't need to edit the model functions to use different parallelization strategies, with shard_map we often do.
8-way batch data parallelism
The simplest multi-device parallelism strategy is to shard the batch of inputs and targets over multiple devices, replicate the parameters over those devices, and apply the model in parallel to those shards of data. To evaluate the total loss, the devices need only communicate with a scalar-sized all-reduce-sum at the end. (To evaluate the gradient of the loss, the devices must perform all-reduce-sums of the parameter gradients in the backward pass.)
from functools import partial

from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax.experimental import mesh_utils

devices = mesh_utils.create_device_mesh((8,))

# replicate initial params on all devices, shard data batch over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P()))

# adapt the loss function to sum the losses across devices
def loss_dp(params, batch):
  @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())
  def loss_spmd(local_batch):
    inputs, targets = local_batch
    predictions = predict(params, inputs)  # use reference 'predict'
    local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
    return jax.lax.pmean(local_loss, 'batch')
  return loss_spmd(batch)
We can check that the loss and its gradients match the reference (base) model:
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_dp)(params, batch))
22.779888
22.779888
from jax.tree_util import tree_map, tree_all

def allclose(a, b):
  return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_dp))(params, batch)))
True
We can print the compiler IR to inspect the gradient computation and verify that the collective all-reduce-sum operations happen where we'd expect: at the end of the forward pass to compute the loss value, and in the backward pass to compute the total parameter gradients.
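One way to do that check (a sketch, not from the original article) is to dump the optimized HLO for the gradient computation and scan the text for the all-reduce collectives:

# lower and compile the backward pass, then print the optimized HLO;
# search the output for "all-reduce" by eye
compiled = jax.jit(jax.grad(loss_dp)).lower(params, batch).compile()
print(compiled.as_text())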
8-way fully sharded data parallelism (FSDP)
Another strategy is to additionally shard the parameters over the devices, all-gathering each one when the full value is needed for the jnp.dot or the bias addition. Since we only keep one full parameter in local device memory at a time, rather than keeping all parameters in every device's memory as in the preceding DP example, we free up significant memory that we can use for larger models or larger batch sizes. And because XLA overlaps computation and inter-device communication, wall-clock time doesn't suffer.
So now we need collectives in two places: the model prediction function predict needs to all-gather the parameters before they're used, and, as in the DP case, the loss function needs to sum the local losses to compute the total loss.
There's one other ingredient we need: we don't want to store the fully gathered parameters from the forward pass for use in the backward pass. Instead, we want to gather them again on the backward pass. We can express that by using jax.remat with a custom policy (or a custom_vjp), though XLA typically performs that rematerialization automatically.
This general FSDP approach is similar to weight update sharding (WUS) and ZeRO-3.
# shard data batch *and params* over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P('batch')))

# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp(params_frag, inputs):
  for W_frag, b_frag in params_frag:
    W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
    b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
    outputs = jnp.dot(inputs, W) + b
    inputs = jax.nn.relu(outputs)
  return outputs

def loss_fsdp(params, batch):
  @partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())
  def loss_spmd(local_params, local_batch):
    inputs, targets = local_batch
    predictions = predict_fsdp(local_params, inputs)
    local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
    return jax.lax.pmean(local_loss, 'batch')
  return loss_spmd(params, batch)
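As a quick sanity check (not part of the original article), we can visualize that a parameter is now sharded over the 'batch' axis rather than replicated across devices:

# each of the 8 devices should hold a distinct row-slice of the first-layer weights
jax.debug.visualize_array_sharding(params[0][0])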
Once again, we can check that the loss and its gradients match the reference model:
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp)(params, batch))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_fsdp))(params, batch)))
22.779888
22.779888
True
8-way tensor parallelism (TP)
Usually we don't use tensor model parallelism by itself, but seeing it in isolation is a good warmup for parallel matrix multiplication. It's also a good example of using shard_map in a library function that gets called within a larger jit-based computation.
The parallelization idea is that we'll keep the data/activations sharded over their feature axis (rather than their batch axis), and we'll similarly shard the weight matrices over their input-feature axis (and the biases over their feature axis). Then, to perform the parallel matrix multiplication, we perform local matrix multiplications followed by a psum_scatter to sum the local results and efficiently scatter the result's shards.
devices = mesh_utils.create_device_mesh((8,))
mesh = Mesh(devices, ('feats',))

batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P('feats')))

def predict_tp(params, inputs):
  for W, b in params:
    outputs = gemm_tp(inputs, W, b)
    inputs = jax.nn.relu(outputs)
  return outputs

@partial(shard_map, mesh=mesh,
         in_specs=(P(None, 'feats'), P('feats', None), P('feats')),
         out_specs=P(None, 'feats'))
def gemm_tp(inputs, W, b):
  block_result = jnp.dot(inputs, W)
  return jax.lax.psum_scatter(block_result, 'feats',
                              scatter_dimension=1, tiled=True) + b

def loss_tp(params, batch):
  inputs, targets = batch
  predictions = predict_tp(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1))  # NOTE psum!
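The original article doesn't check this TP variant against the reference, but mirroring the earlier examples, a quick consistency check might look like this:

# compare the reference loss with the tensor-parallel loss
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_tp)(params, batch))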
FSDP + TP, with shard_map at the top level
We can compose these strategies together, using multiple axes of parallelism.
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('batch', 'feats'))

batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))
params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))

# mostly same as previous predict_fsdp definition, except we call gemm_tp
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp_tp(params_frag, inputs):
  for W_frag, b_frag in params_frag:
    W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
    b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
    block_result = jnp.dot(inputs, W)
    outputs = jax.lax.psum_scatter(block_result, 'feats',
                                   scatter_dimension=1, tiled=True) + b
    inputs = jax.nn.relu(outputs)
  return outputs

@partial(shard_map, mesh=mesh,
         in_specs=(P(('feats', 'batch')), P('batch', 'feats')),
         out_specs=P())
def loss_fsdp_tp(local_params, local_batch):
  inputs, targets = local_batch
  predictions = predict_fsdp_tp(local_params, inputs)
  sq_err = jax.lax.psum(jnp.sum((predictions - targets)**2, axis=-1), 'feats')
  return jax.lax.pmean(jnp.mean(sq_err), 'batch')
Notice how we have to do two collective reductions: one over 'feats' and one over 'batch'. In the pure TP example, we didn't write the 'feats' reduction explicitly because we only used shard_map within gemm_tp; in the caller loss_tp, the compiler automatically translated our use of jnp.sum into the psum needed given the sharded result returned by predict_tp.
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp_tp)(params_, batch_))

print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_fsdp_tp))(params_, batch_)))
22.779886
22.779886
True
SPMD pipeline parallelism (PP)
With pipeline parallelism, we aim to parallelize the evaluation of layers at different depths in the network. For example, one device might compute the application of the first layer while another device computes the application of the second; when they finish, the first device passes its result to the second while the second passes its result to the device responsible for the third layer, and the process repeats. In general, the number of pipeline stages may differ from the number of layers, since each stage may be responsible for multiple layers.
With SPMD pipelining, we exploit the fact that most layers in the network apply the same computation, just with different parameter values. In particular, we can stack together all the parameters except those of the first and last layers, then use shard_map to map over blocks of those layer parameters, with each block of parameters corresponding to a pipeline stage. We then use the jax.lax.ppermute collective to shift data down the parallel pipeline.
This particular pipelining strategy is essentially the GPipe strategy. There are several variants, as well as quite different strategies, and which one is appropriate can depend on the networking speed between stages and on batch sizes. But for this tutorial we'll focus on just this one strategy.
First, we pick some pipeline parameters:
L = len(params) - 2        # num layers, excluding first and last
N = batch_size             # batch size
F = params[0][0].shape[1]  # num features

# choose some pipeline parameters
S = 2      # number of stages
B = 8      # size of each microbatch
assert L % S == 0, "S (number of stages) must divide L (number of inner layers)"

# compute some useful quantities
M, ragged = divmod(N, B)  # M is number of microbatches
assert not ragged, "B (size of each microbatch) must divide total batch size"
K, ragged = divmod(M, S)  # K is microbatches per stage
assert not ragged, "S (number of stages) must divide number of microbatches"

print(f'{S} stages, {L // S} layer(s) per stage, {L} pipelined layers total')
print(f'{B} examples per microbatch, {M} microbatches total')
2 stages, 2 layer(s) per stage, 4 pipelined layers total
8 examples per microbatch, 4 microbatches total
mesh = Mesh(jax.devices()[:S], ('stages',))

def predict_pp(params, inputs):
  (W_first, b_first), inner_params, (W_last, b_last) = params
  inputs = jax.nn.relu(jnp.dot(inputs, W_first) + b_first)
  inputs = spmd_pipeline(lambda Wb, x: jax.nn.relu(x @ Wb[0] + Wb[1]),
                         inner_params, inputs)
  outputs = jnp.dot(inputs, W_last) + b_last
  return outputs

@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),
         out_specs=P())
def loss_pp(params, batch):
  inputs, targets = batch
  predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1)
  local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
  return jax.lax.pmean(local_loss, 'stages')
def spmd_pipeline(fn, stage_params, inputs):
  stage = jax.lax.axis_index('stages')
  outputs = jnp.zeros_like(inputs) * jnp.nan
  state = jnp.zeros((L // S, B, F)) * jnp.nan
  for i in range(M+L-1):
    state = state.at[0].set(jnp.where(stage == 0, inputs[i % K], state[0]))
    state = jax.vmap(fn)(stage_params, state)
    outputs = outputs.at[(i-L+1) % K].set(jnp.where(stage == S-1, state[-1], outputs[(i-L+1) % K]))
    state, inputs, outputs = shift(i, state, inputs, outputs)
  outputs = jax.lax.ppermute(outputs, 'stages', [(i, (i+1) % S) for i in range(S)])
  return outputs

def shift(i, state, inputs, outputs):
  sh = lambda x, d: jax.lax.ppermute(x, 'stages', [(i, (i+d) % S) for i in range(S)])
  state = jnp.roll(state, +1, axis=0).at[0].set(sh(state[-1], +1))
  if (i % K) == (-1 % K):
    inputs = sh(inputs, +1)
  if ((i-L+1) % K) == (-1 % K):
    outputs = sh(outputs, +1)
  return state, inputs, outputs
first_params, *inner_params, last_params = params
Ws, bs = zip(*inner_params)
params_stacked = jnp.stack(Ws), jnp.stack(bs)

first_params = jax.device_put(first_params, NamedSharding(mesh, P()))
params_stacked = jax.device_put(params_stacked, NamedSharding(mesh, P('stages')))
last_params = jax.device_put(last_params, NamedSharding(mesh, P()))
params_ = first_params, params_stacked, last_params

batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages')))
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_pp)(params_, batch_))
22.779886
22.779884
_ = jax.jit(jax.grad(loss_pp))(params_, batch_) # don't crash