JAX 中文文档(七)(2)https://developer.aliyun.com/article/1559697
我们如何在实践中使用 shard_map
并行化矩阵乘法对于扩展深度学习模型至关重要,无论是用于训练还是推断。当 jax.jit 自动并行化矩阵乘法时,它可以使用几种不同的策略,这取决于矩阵大小、硬件细节和其他因素。我们如何使用 shard_map 更明确地编写这些并行化例程?又如何优化它们以获得更好的计算/通信重叠,从而提高 FLOP 利用率?
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
# One-dimensional device mesh over (at most) the first four devices, axis 'i'.
mesh = Mesh(jax.devices()[:4], ('i',))


def device_put(x, pspec):
    """Place ``x`` on ``mesh`` with the layout described by ``pspec``.

    Args:
        x: value to transfer.
        pspec: partition spec, interpreted relative to the module-level
            ``mesh``.

    Returns:
        The result of ``jax.device_put`` with a ``NamedSharding`` built from
        ``mesh`` and ``pspec``.
    """
    sharding = NamedSharding(mesh, pspec)
    return jax.device_put(x, sharding)
示例 1:all-gather
# Example 1 layout: both operands are split along their first axis over 'i'.
lhs_spec = P('i', None)
rhs_spec = P('i', None)

# Random 8x8 and 8x4 operands, each placed with its spec.
lhs = device_put(jax.random.normal(jax.random.key(0), shape=(8, 8)), lhs_spec)
rhs = device_put(jax.random.normal(jax.random.key(1), shape=(8, 4)), rhs_spec)
@jax.jit
@partial(
    shard_map,
    mesh=mesh,
    in_specs=(lhs_spec, rhs_spec),
    out_specs=rhs_spec,
)
def matmul_allgather(lhs_block, rhs_block):
    """Per-device matmul that first all-gathers the right-hand operand.

    Args:
        lhs_block: this device's block of the left operand (per ``lhs_spec``).
        rhs_block: this device's block of the right operand (per ``rhs_spec``).

    Returns:
        ``lhs_block @ gathered_rhs``, where ``gathered_rhs`` is the tiled
        all-gather of ``rhs_block`` over mesh axis ``'i'``.
    """
    gathered_rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)
    return lhs_block @ gathered_rhs
# Compare the sharded product against the plain matmul.
out = matmul_allgather(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, rtol=1e-3, atol=1e-3))
这很棒,但这里没有计算/通信重叠:在开始矩阵乘法之前,我们必须等待 all_gather 完成。这里是使用相同代码、但在更大的示例形状上(lhs 为 (8192, 8192),rhs 为 (8192, 1024))进行的性能分析。
如果我们不调用 all_gather,而是把上面基于 ppermute 的 all_gather 实现内联进来,并将聚集置换的各个步骤与本地矩阵乘法交错执行,就可以获得计算/通信重叠:
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather_overlapped(lhs_block, rhs_block):
    """Matmul with the all-gather of ``rhs`` unrolled into ring permutes.

    Instead of gathering the whole right operand up front, each step
    multiplies one column block of ``lhs_block`` by the ``rhs_block``
    currently held, then passes that ``rhs_block`` to the next device.

    Args:
        lhs_block: this device's block of the left operand (per ``lhs_spec``).
        rhs_block: this device's block of the right operand (per ``rhs_spec``).

    Returns:
        The accumulated product block for this device.
    """
    size = jax.lax.psum(1, 'i')  # number of devices along axis 'i'
    idx = jax.lax.axis_index('i')
    # Ring permutation sending device i's value to device i + 1.
    shift = partial(jax.lax.ppermute, axis_name='i',
                    perm=[(i, (i + 1) % size) for i in range(size)])

    B = lhs_block.shape[1] // size  # width of one lhs column block

    def lhs_blocks(i):
        # Column block i of lhs_block. Uses the fully qualified jax.lax so the
        # block does not depend on a separate `lax` import.
        return jax.lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)

    out_block = lhs_blocks(idx) @ rhs_block
    for i in range(1, size):
        # After i shifts this device holds the rhs block that started on
        # device (idx - i) % size, so pair it with that lhs column block.
        rhs_block = shift(rhs_block)
        out_block += lhs_blocks((idx - i) % size) @ rhs_block
    return out_block
# Check the overlapped all-gather variant against the plain matmul.
out = matmul_allgather_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, rtol=1e-3, atol=1e-3))
此实现允许通信与计算重叠,并且还避免了在每个设备上聚合大量中间数据。但在 TPU 上,如果只沿环的一个方向置换,就只用到了一半的互连带宽。要进行双向置换,我们只需将块分成两半,并把两半分别沿两个方向发送:
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):
    """Bidirectional variant of the overlapped all-gather matmul.

    ``rhs_block`` is split in two halves along axis 0; the low half travels
    around the ring with ``shift_up`` and the high half with ``shift_dn``,
    and each half is multiplied by the matching half-width column block of
    ``lhs_block``.

    Args:
        lhs_block: this device's block of the left operand (per ``lhs_spec``).
        rhs_block: this device's block of the right operand (per ``rhs_spec``).

    Returns:
        The accumulated product block for this device.
    """
    size = jax.lax.psum(1, 'i')  # number of devices along axis 'i'
    idx = jax.lax.axis_index('i')
    shift_up = partial(jax.lax.ppermute, axis_name='i',
                       perm=[(i, (i + 1) % size) for i in range(size)])
    shift_dn = partial(jax.lax.ppermute, axis_name='i',
                       perm=[(i, (i - 1) % size) for i in range(size)])

    B = lhs_block.shape[1] // size // 2  # half-size blocks

    def lhs_blocks(i, hi):
        # Half-width column block: hi=0 is the low half of block i, hi=1 the
        # high half. Uses the fully qualified jax.lax so the block does not
        # depend on a separate `lax` import.
        return jax.lax.dynamic_slice_in_dim(lhs_block, (2 * i + hi) * B, B, 1)

    rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)

    out_block = lhs_blocks(idx, 0) @ rhs_block_lo
    out_block += lhs_blocks(idx, 1) @ rhs_block_hi
    for i in range(1, size):
        rhs_block_lo = shift_up(rhs_block_lo)
        rhs_block_hi = shift_dn(rhs_block_hi)
        # The low half now comes from device idx - i, the high half from
        # device idx + i (both modulo size).
        out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo
        out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi
    return out_block
# Check the bidirectional all-gather variant against the plain matmul.
out = matmul_allgather_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, rtol=1e-3, atol=1e-3))
示例 2:psum_scatter
# Example 2 layout: lhs is split along its second axis, rhs along its first.
lhs_spec = P(None, 'i')
rhs_spec = P('i', None)

# Re-place the existing operands with the new specs.
lhs = device_put(lhs, lhs_spec)
rhs = device_put(rhs, rhs_spec)
@partial(
    shard_map,
    mesh=mesh,
    in_specs=(lhs_spec, rhs_spec),
    out_specs=rhs_spec,
)
def matmul_psumscatter(lhs_block, rhs_block):
    """Per-device partial product followed by a tiled ``psum_scatter``.

    Args:
        lhs_block: this device's block of the left operand (per ``lhs_spec``).
        rhs_block: this device's block of the right operand (per ``rhs_spec``).

    Returns:
        This device's share of the summed product, as produced by
        ``jax.lax.psum_scatter`` over mesh axis ``'i'``.
    """
    return jax.lax.psum_scatter(lhs_block @ rhs_block, 'i', tiled=True)


# Check the result against the plain matmul.
out = matmul_psumscatter(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, rtol=1e-3, atol=1e-3))
@partial(
    shard_map,
    mesh=mesh,
    in_specs=(lhs_spec, rhs_spec),
    out_specs=rhs_spec,
)
def matmul_psumscatter_overlapped(lhs_block, rhs_block):
    """``psum_scatter`` matmul unrolled into ring permutes.

    The rows of ``lhs_block`` are split into one chunk per device. A running
    partial sum is passed around the ring and, at every step, the product of
    one row chunk with ``rhs_block`` is added to it.

    Args:
        lhs_block: this device's block of the left operand (per ``lhs_spec``).
        rhs_block: this device's block of the right operand (per ``rhs_spec``).

    Returns:
        The accumulated partial sum held by this device after the last step.
    """
    n_dev = jax.lax.psum(1, 'i')  # number of devices along axis 'i'
    me = jax.lax.axis_index('i')
    # Ring permutation sending device j's value to device j - 1.
    send_down = partial(
        jax.lax.ppermute,
        axis_name='i',
        perm=[(j, (j - 1) % n_dev) for j in range(n_dev)],
    )

    # Split the first axis into n_dev row chunks.
    row_chunks = lhs_block.reshape(n_dev, -1, lhs_block.shape[1])

    acc = row_chunks[(me + 1) % n_dev] @ rhs_block
    for step in range(1, n_dev):
        acc = send_down(acc)
        acc += row_chunks[(me + step + 1) % n_dev] @ rhs_block
    return acc
# Check the overlapped psum_scatter variant against the plain matmul.
out = matmul_psumscatter_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, rtol=1e-3, atol=1e-3))
如前例所示,为了充分利用 TPU 上的互连,我们将运行一个双向版本:
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):
    """Bidirectional variant of the overlapped ``psum_scatter`` matmul.

    Two running partial sums are kept: the "lo" one is built from the low
    half of each row block and travels with ``shift_up``; the "hi" one is
    built from the high half and travels with ``shift_dn``. The two results
    are concatenated at the end.

    Args:
        lhs_block: this device's block of the left operand (per ``lhs_spec``).
        rhs_block: this device's block of the right operand (per ``rhs_spec``).

    Returns:
        ``jnp.concatenate([out_summand_lo, out_summand_hi])`` for this device.
    """
    size = jax.lax.psum(1, 'i')  # number of devices along axis 'i'
    idx = jax.lax.axis_index('i')
    shift_up = partial(jax.lax.ppermute, axis_name='i',
                       perm=[(i, (i + 1) % size) for i in range(size)])
    shift_dn = partial(jax.lax.ppermute, axis_name='i',
                       perm=[(i, (i - 1) % size) for i in range(size)])

    B = lhs_block.shape[0] // size // 2  # half-size blocks

    def lhs_blocks(i, hi):
        # Half-height row block: hi=0 is the low half of block i, hi=1 the
        # high half. Uses the fully qualified jax.lax so the block does not
        # depend on a separate `lax` import.
        return jax.lax.dynamic_slice_in_dim(lhs_block, (2 * i + hi) * B, B, 0)

    out_summand_lo = lhs_blocks((idx - 1) % size, 0) @ rhs_block
    out_summand_hi = lhs_blocks((idx + 1) % size, 1) @ rhs_block
    for i in range(1, size):
        out_summand_lo = shift_up(out_summand_lo)
        out_summand_hi = shift_dn(out_summand_hi)
        out_summand_lo += lhs_blocks((idx - i - 1) % size, 0) @ rhs_block
        out_summand_hi += lhs_blocks((idx + i + 1) % size, 1) @ rhs_block
    return jnp.concatenate([out_summand_lo, out_summand_hi])
# Check the bidirectional psum_scatter variant against the plain matmul.
out = matmul_psumscatter_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, rtol=1e-3, atol=1e-3))
import jax
import jax.numpy as jnp


def predict(params, inputs):
    """Run a stack of dense layers and return the last pre-activation.

    Args:
        params: iterable of ``(W, b)`` pairs, one per layer.
        inputs: input array fed to the first layer.

    Returns:
        ``dot(x, W) + b`` of the final layer; ReLU is applied between layers
        but not to the returned value.
    """
    activations = inputs
    for weights, bias in params:
        outputs = jnp.dot(activations, weights) + bias
        activations = jax.nn.relu(outputs)
    return outputs


def loss(params, batch):
    """Mean over examples of the squared error summed over the last axis.

    Args:
        params: layer parameters accepted by ``predict``.
        batch: ``(inputs, targets)`` pair.

    Returns:
        Scalar loss value.
    """
    inputs, targets = batch
    residual = predict(params, inputs) - targets
    per_example = jnp.sum(residual**2, axis=-1)
    return jnp.mean(per_example)
def init_layer(key, n_in, n_out):
    """Sample one dense layer's parameters.

    Args:
        key: PRNG key; it is split once, one half per parameter.
        n_in: input width.
        n_out: output width.

    Returns:
        ``(W, b)`` where ``W`` has shape ``(n_in, n_out)`` and is a normal
        sample divided by ``sqrt(n_in)``, and ``b`` is a normal sample of
        shape ``(n_out,)``.
    """
    w_key, b_key = jax.random.split(key)
    W = jax.random.normal(w_key, (n_in, n_out)) / jnp.sqrt(n_in)
    b = jax.random.normal(b_key, (n_out,))
    return W, b


def init(key, layer_sizes, batch_size):
    """Build random layer parameters and one random ``(inputs, targets)`` batch.

    Args:
        key: PRNG key.
        layer_sizes: layer widths; consecutive pairs define the layers.
        batch_size: number of examples in the generated batch.

    Returns:
        ``(params, (inputs, targets))`` with ``inputs`` of shape
        ``(batch_size, layer_sizes[0])`` and ``targets`` of shape
        ``(batch_size, layer_sizes[-1])``.
    """
    # First sub-key is carried forward; the rest seed one layer each.
    key, *layer_keys = jax.random.split(key, len(layer_sizes))
    fan_ins, fan_outs = layer_sizes[:-1], layer_sizes[1:]
    params = [
        init_layer(layer_key, n_in, n_out)
        for layer_key, n_in, n_out in zip(layer_keys, fan_ins, fan_outs)
    ]

    key, inputs_key, targets_key = jax.random.split(key, 3)
    inputs = jax.random.normal(inputs_key, (batch_size, layer_sizes[0]))
    targets = jax.random.normal(targets_key, (batch_size, layer_sizes[-1]))
    return params, (inputs, targets)
# 784 inputs, five hidden layers of width 128, 8 outputs.
layer_sizes = [784] + [128] * 5 + [8]
batch_size = 32

params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)
8 路批次数据并行
from functools import partial

from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = mesh_utils.create_device_mesh((8,))

# Replicate the initial params on all devices; shard the data batch over them.
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P()))


def loss_dp(params, batch):
    """Data-parallel loss: per-device loss averaged over mesh axis 'batch'.

    Args:
        params: layer parameters, closed over by the per-device function.
        batch: ``(inputs, targets)`` pair, mapped with ``P('batch', None)``.

    Returns:
        The ``pmean`` over ``'batch'`` of each device's local loss.
    """

    def loss_spmd(local_batch):
        inputs, targets = local_batch
        predictions = predict(params, inputs)  # reuse the reference `predict`
        per_example = jnp.sum((predictions - targets)**2, axis=-1)
        local_loss = jnp.mean(per_example)
        return jax.lax.pmean(local_loss, 'batch')

    sharded_loss = shard_map(
        loss_spmd,
        mesh=mesh,
        in_specs=P('batch', None),
        out_specs=P(),
    )
    return sharded_loss(batch)
# Reference loss, then the data-parallel loss, on the same inputs.
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_dp)(params, batch))
22.779888 22.779888
def allclose(a, b):
    """Return whether two pytrees are leaf-wise close.

    Every pair of corresponding leaves is compared with ``jnp.allclose``
    (``atol=1e-2``, ``rtol=1e-2``) and the results are reduced with
    ``tree_all``.

    Args:
        a: first pytree.
        b: second pytree, with the same structure as ``a``.

    Returns:
        True when all leaf comparisons hold.
    """
    leaf_close = partial(jnp.allclose, atol=1e-2, rtol=1e-2)
    # Qualified through jax.tree_util so the helper does not rely on bare
    # `tree_all` / `tree_map` names being imported elsewhere.
    return jax.tree_util.tree_all(jax.tree_util.tree_map(leaf_close, a, b))


# Gradients of the reference loss and the data-parallel loss should agree.
print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_dp))(params, batch)))
我们可以打印编译器 IR 来检查梯度计算,并验证集合通信的 all-reduce 求和操作出现在预期的位置:在前向传播末尾计算损失值时,以及在反向传播中计算总参数梯度时。
8 路完全分片数据并行(FSDP)
另一种策略是把参数也分片到各个设备上,并在 jnp.dot 或偏置添加需要完整值时对每个参数进行全部聚集(all-gather)。由于我们每次只在本地设备内存中保留一个完整的参数,而不像前面的 DP 示例那样在所有设备内存中保留所有参数,这样就可以释放出大量内存,用于更大的模型或更大的批处理大小。并且由于 XLA 会重叠计算和设备间通信,所以墙钟时间不会受影响。
因此现在有两处需要集合通信:模型预测函数需要在使用参数之前对其进行全部聚集,而与 DP 情况一样,损失函数需要对本地损失进行求和以计算总损失。
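还需要一个要素:我们不希望保存前向传播中聚集得到的完整参数以供反向传播使用,而是希望在反向传播时重新聚集它们。我们可以使用带自定义策略的 jax.remat(或 custom_vjp)来表达这一点,尽管 XLA 通常会自动进行这种重计算(rematerialization)。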
这种通用的FSDP 方法类似于权重更新分片(WUS)和ZeRO-3。
# Shard the data batch *and the params* over the devices.
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P('batch')))


# Gather the weights just before they are used, and re-gather them on the
# backward pass rather than saving them: the remat policy returns False for
# the 'all_gather' op and True for everything else.
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp(params_frag, inputs):
    """Forward pass over parameter fragments, gathering each before use.

    Args:
        params_frag: iterable of ``(W_frag, b_frag)`` per-device fragments.
        inputs: this device's inputs.

    Returns:
        The final layer's pre-activation output.
    """
    gather = partial(jax.lax.all_gather, axis_name='batch', tiled=True)
    activations = inputs
    for W_frag, b_frag in params_frag:
        W = gather(W_frag)
        b = gather(b_frag)
        outputs = jnp.dot(activations, W) + b
        activations = jax.nn.relu(outputs)
    return outputs


def loss_fsdp(params, batch):
    """FSDP loss: params and batch are both mapped with ``P('batch')``.

    Args:
        params: layer parameters.
        batch: ``(inputs, targets)`` pair.

    Returns:
        The ``pmean`` over ``'batch'`` of each device's local loss.
    """

    def loss_spmd(local_params, local_batch):
        inputs, targets = local_batch
        predictions = predict_fsdp(local_params, inputs)
        per_example = jnp.sum((predictions - targets)**2, axis=-1)
        local_loss = jnp.mean(per_example)
        return jax.lax.pmean(local_loss, 'batch')

    sharded_loss = shard_map(
        loss_spmd,
        mesh=mesh,
        in_specs=P('batch'),
        out_specs=P(),
    )
    return sharded_loss(params, batch)
# Loss values first, then a gradient comparison against the reference loss.
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp)(params, batch))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_fsdp))(params, batch)))
22.779888 22.779888 True
8 路张量并行性(TP)
devices = mesh_utils.create_device_mesh((8,))
mesh = Mesh(devices, ('feats',))

batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P('feats')))


@partial(
    shard_map,
    mesh=mesh,
    in_specs=(P(None, 'feats'), P('feats', None), P('feats')),
    out_specs=P(None, 'feats'),
)
def gemm_tp(inputs, W, b):
    """Per-device block matmul, reduce-scattered over 'feats', plus bias.

    Args:
        inputs: this device's block of the inputs.
        W: this device's block of the weight matrix.
        b: this device's block of the bias.

    Returns:
        ``psum_scatter`` of the block product along dimension 1, plus ``b``.
    """
    partial_product = jnp.dot(inputs, W)
    reduced = jax.lax.psum_scatter(
        partial_product, 'feats', scatter_dimension=1, tiled=True)
    return reduced + b


def predict_tp(params, inputs):
    """Forward pass that uses ``gemm_tp`` for every layer.

    Args:
        params: iterable of ``(W, b)`` pairs.
        inputs: input array fed to the first layer.

    Returns:
        The final layer's pre-activation output.
    """
    activations = inputs
    for weights, bias in params:
        outputs = gemm_tp(activations, weights, bias)
        activations = jax.nn.relu(outputs)
    return outputs


def loss_tp(params, batch):
    """Mean squared-error loss on top of ``predict_tp``.

    Args:
        params: layer parameters.
        batch: ``(inputs, targets)`` pair.

    Returns:
        Scalar loss value.
    """
    inputs, targets = batch
    residual = predict_tp(params, inputs) - targets
    # NOTE psum! (this sum runs outside shard_map)
    per_example = jnp.sum(residual ** 2, axis=-1)
    return jnp.mean(per_example)
FSDP + TP,在顶层使用shard_map
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('batch', 'feats'))

batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))
params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))


# Mostly the same as the earlier predict_fsdp, except that the matmul is
# followed by a psum_scatter over 'feats'.
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp_tp(params_frag, inputs):
    """Forward pass: gather params over 'batch', reduce-scatter over 'feats'.

    Args:
        params_frag: iterable of ``(W_frag, b_frag)`` per-device fragments.
        inputs: this device's block of the inputs.

    Returns:
        The final layer's pre-activation output block.
    """
    gather = partial(jax.lax.all_gather, axis_name='batch', tiled=True)
    activations = inputs
    for W_frag, b_frag in params_frag:
        W = gather(W_frag)
        b = gather(b_frag)
        partial_product = jnp.dot(activations, W)
        outputs = jax.lax.psum_scatter(
            partial_product, 'feats', scatter_dimension=1, tiled=True) + b
        activations = jax.nn.relu(outputs)
    return outputs


# NOTE(review): params_ is placed with P(('batch', 'feats')) while in_specs
# below uses P(('feats', 'batch')); confirm the ordering difference is intended.
@partial(
    shard_map,
    mesh=mesh,
    in_specs=(P(('feats', 'batch')), P('batch', 'feats')),
    out_specs=P(),
)
def loss_fsdp_tp(local_params, local_batch):
    """Loss with explicit reductions over both mesh axes.

    Args:
        local_params: this device's parameter fragments.
        local_batch: this device's ``(inputs, targets)`` block.

    Returns:
        ``pmean`` over ``'batch'`` of the mean of the squared error, where
        the squared error is ``psum``-ed over ``'feats'``.
    """
    inputs, targets = local_batch
    residual = predict_fsdp_tp(local_params, inputs) - targets
    sq_err = jax.lax.psum(jnp.sum(residual**2, axis=-1), 'feats')
    return jax.lax.pmean(jnp.mean(sq_err), 'batch')
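请注意,我们需要做两次集合归约:一次在 'feats' 上,一次在 'batch' 上。在纯 TP 示例中,我们没有显式写出 'feats' 上的归约,因为我们只在 gemm_tp 内部使用了 shard_map;在调用方 loss_tp 中,编译器会根据 predict_tp 返回的分片结果,自动把 jnp.sum 转换为所需的 psum。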
# Loss values first, then a gradient comparison against the reference loss.
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp_tp)(params_, batch_))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_fsdp_tp))(params, batch)))
22.779886 22.779886 True
SPMD 管道并行性(PP)
使用 SPMD 流水线时,我们利用了这样一个事实:网络中的大多数层执行的是相同的计算,只是参数值不同。特别是,我们可以把除第一层和最后一层之外的所有参数堆叠起来,然后使用 shard_map 在这些层参数的块上进行映射,其中每个参数块对应一个流水线阶段。随后我们使用 jax.lax.ppermute 集合操作把数据沿并行流水线向下传递。
这种特定的流水线策略本质上就是 GPipe 策略。它有若干变体,也存在相当不同的其他策略,哪一种合适取决于各阶段之间的网络速度和批量大小。但在本教程中,我们只关注这一种策略。
# Problem sizes taken from the model defined earlier.
L = len(params) - 2         # num layers, excluding first and last
N = batch_size              # batch size
F = params[0][0].shape[1]   # num features

# Pipeline parameters.
S = 2   # number of stages
B = 8   # size of each microbatch
assert L % S == 0, "S (number of stages) must divide L (number of inner layers)"

# Derived quantities.
M, ragged = divmod(N, B)    # M is number of microbatches
assert not ragged, "B (size of each microbatch) must divide total batch size"
K, ragged = divmod(M, S)    # K is microbatches per stage
assert not ragged, "S (number of stages) must divide number of microbatches"

print(f'{S} stages, {L // S} layer(s) per stage, {L} pipelined layers total')
print(f'{B} examples per microbatch, {M} microbatches total')
2 stages, 2 layer(s) per stage, 4 pipelined layers total 8 examples per microbatch, 4 microbatches total
# One mesh axis, 'stages', over the first S devices.
mesh = Mesh(jax.devices()[:S], ('stages',))


def predict_pp(params, inputs):
    """Forward pass whose inner layers run through ``spmd_pipeline``.

    Args:
        params: ``(first_layer, inner_params, last_layer)``; the first and
            last entries are ``(W, b)`` pairs, ``inner_params`` is handed to
            ``spmd_pipeline`` unchanged.
        inputs: input array fed to the first layer.

    Returns:
        The last layer's pre-activation output.
    """
    first_layer, inner_params, last_layer = params
    W_first, b_first = first_layer
    W_last, b_last = last_layer

    def inner_layer(Wb, x):
        W, b = Wb
        return jax.nn.relu(x @ W + b)

    hidden = jax.nn.relu(jnp.dot(inputs, W_first) + b_first)
    hidden = spmd_pipeline(inner_layer, inner_params, hidden)
    return jnp.dot(hidden, W_last) + b_last


@partial(
    shard_map,
    mesh=mesh,
    in_specs=((P(), P('stages'), P()), P('stages')),
    out_specs=P(),
)
def loss_pp(params, batch):
    """Pipeline-parallel loss, averaged over the 'stages' axis.

    Args:
        params: ``(first_layer, inner_params, last_layer)`` as expected by
            ``predict_pp``.
        batch: this device's ``(inputs, targets)`` block.

    Returns:
        ``pmean`` over ``'stages'`` of the local mean squared-error loss.
    """
    inputs, targets = batch
    # Reshape to (K, B, features) for the pipeline, then flatten back.
    microbatched = inputs.reshape(K, B, -1)
    predictions = predict_pp(params, microbatched).reshape(K * B, -1)
    per_example = jnp.sum((predictions - targets)**2, axis=-1)
    local_loss = jnp.mean(per_example)
    return jax.lax.pmean(local_loss, 'stages')
def spmd_pipeline(fn, stage_params, inputs):
    """Run ``fn`` over this stage's layers for ``M + L - 1`` pipeline steps.

    Args:
        fn: layer function taking ``(layer_params, x)``; it is ``vmap``-ed
            over this stage's layers.
        stage_params: the layer parameters held by this stage.
        inputs: microbatches, indexed along axis 0 with indices modulo ``K``.

    Returns:
        The ``outputs`` buffer after one final ring permute over ``'stages'``.
    """
    stage = jax.lax.axis_index('stages')
    on_first_stage = stage == 0
    on_last_stage = stage == S - 1

    # Both buffers start out filled with NaN.
    outputs = jnp.zeros_like(inputs) * jnp.nan
    state = jnp.zeros((L // S, B, F)) * jnp.nan

    for step in range(M + L - 1):
        in_slot = step % K
        out_slot = (step - L + 1) % K
        # Stage 0 loads the next microbatch into the first layer slot.
        state = state.at[0].set(
            jnp.where(on_first_stage, inputs[in_slot], state[0]))
        state = jax.vmap(fn)(stage_params, state)
        # The last stage records the result of its final layer slot.
        outputs = outputs.at[out_slot].set(
            jnp.where(on_last_stage, state[-1], outputs[out_slot]))
        state, inputs, outputs = shift(step, state, inputs, outputs)

    ring = [(src, (src + 1) % S) for src in range(S)]
    outputs = jax.lax.ppermute(outputs, 'stages', ring)
    return outputs


def shift(i, state, inputs, outputs):
    """Advance the pipeline by one step.

    Args:
        i: Python integer step index, used for the static conditions below.
        state: per-layer activations held by this stage.
        inputs: input microbatch buffer.
        outputs: output microbatch buffer.

    Returns:
        ``(state, inputs, outputs)`` after the shift.
    """
    ring = [(src, (src + 1) % S) for src in range(S)]

    def to_next_stage(x):
        return jax.lax.ppermute(x, 'stages', ring)

    # Roll the layer slots by one and fill slot 0 with the value permuted in
    # from the previous stage's last slot.
    incoming = to_next_stage(state[-1])
    state = jnp.roll(state, +1, axis=0).at[0].set(incoming)

    last_slot = -1 % K
    if (i % K) == last_slot:
        inputs = to_next_stage(inputs)
    if ((i - L + 1) % K) == last_slot:
        outputs = to_next_stage(outputs)
    return state, inputs, outputs
# Separate the first and last layers from the inner ones.
first_params, *inner_params, last_params = params

# Stack the inner layers' weights and biases along a new leading axis.
Ws, bs = zip(*inner_params)
params_stacked = (jnp.stack(Ws), jnp.stack(bs))

# Placement: first/last layers with P(), the stacked inner layers and the
# batch with P('stages').
first_params = jax.device_put(first_params, NamedSharding(mesh, P()))
params_stacked = jax.device_put(params_stacked, NamedSharding(mesh, P('stages')))
last_params = jax.device_put(last_params, NamedSharding(mesh, P()))
params_ = (first_params, params_stacked, last_params)

batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages')))
# Reference loss, then the pipeline-parallel loss.
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_pp)(params_, batch_))
22.779886 22.779884
# Smoke test: the gradient of the pipelined loss should run without crashing.
_ = jax.jit(jax.grad(loss_pp))(params_, batch_)
JAX 中文文档(七)(4)https://developer.aliyun.com/article/1559700