JAX Chinese Documentation (7) (3)

Introduction: JAX Chinese Documentation (7)

JAX Chinese Documentation (7) (2): https://developer.aliyun.com/article/1559697


Toy examples

How might we use shard_map and collective communication in practice? These examples, while simple, give some ideas.

Matrix multiplication

Parallelizing matrix multiplication is central to scaling up deep learning models, both for training and for inference. When jax.jit automatically parallelizes matrix multiplication, it can use one of several different strategies, depending on matrix sizes, hardware details, and other factors. How might we write some of those parallelized routines more explicitly using shard_map? And how can we optimize them to get better compute/communication overlap, and thus improve FLOP utilization?

from functools import partial

import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices()[:4], ('i',))

def device_put(x, pspec):
  return jax.device_put(x, NamedSharding(mesh, pspec))

Example 1: all-gather on one side

Consider performing a matrix multiplication where we shard the left-hand argument (think: parameters) over its leading (non-contracting) dimension:

lhs_spec = P('i', None)
lhs = device_put(jax.random.normal(jax.random.key(0), (8, 8)), lhs_spec) 

and we shard the right-hand argument (think: activations) over its contracting dimension, with the output sharded like the right-hand side:

rhs_spec = P('i', None)
rhs = device_put(jax.random.normal(jax.random.key(1), (8, 4)), rhs_spec) 

To perform this matrix multiplication, we can first all-gather the right-hand side and then perform a local matrix multiply against the sharded left-hand side:

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather(lhs_block, rhs_block):
  rhs = jax.lax.all_gather(rhs_block, 'i', tiled=True)
  return lhs_block @ rhs 
out = matmul_allgather(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3)) 
True 

That's great, but we don't get any compute/communication overlap here: before we can start the matmul, we need the all_gather to complete. Here's a profile using the same code, but on larger example shapes (lhs of shape (8192, 8192) and rhs of shape (8192, 1024)):
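
As a rough sketch of how such a profile might be captured (the /tmp/jax-trace log directory and the lhs_big/rhs_big arrays below are illustrative, not from the original), we can run the same sharded matmul on the larger shapes under jax.profiler.trace and inspect the resulting trace in TensorBoard or Perfetto:

# Sketch: capture a trace of the non-overlapped matmul on larger shapes.
lhs_big = device_put(jax.random.normal(jax.random.key(0), (8192, 8192)), lhs_spec)
rhs_big = device_put(jax.random.normal(jax.random.key(1), (8192, 1024)), rhs_spec)
with jax.profiler.trace('/tmp/jax-trace'):  # example log directory
  matmul_allgather(lhs_big, rhs_big).block_until_ready()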

If instead of calling all_gather we essentially inline our implementation of all_gather in terms of ppermute from above, we get compute/communication overlap, interleaving the steps of the gathering permutation with steps of the local matrix multiplication:

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather_overlapped(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift = partial(jax.lax.ppermute, axis_name='i',
                  perm=[(i, (i + 1) % size) for i in range(size)])
  B = lhs_block.shape[1] // size
  lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)
  out_block = lhs_blocks(idx) @ rhs_block
  for i in range(1, size):
    rhs_block = shift(rhs_block)
    out_block += lhs_blocks((idx - i) % size) @ rhs_block
  return out_block 
out = matmul_allgather_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3)) 
True 

This implementation allows overlap between communication and computation, and also avoids gathering a large intermediate on each device. But on TPU it uses only half the interconnect bandwidth, because it permutes in only one direction along the ring. To permute bidirectionally, we just split the blocks in half and send each half in each direction:

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather_overlapped_bidi(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift_up = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i + 1) % size) for i in range(size)])
  shift_dn = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i - 1) % size) for i in range(size)])
  B = lhs_block.shape[1] // size // 2  # half-size blocks
  lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 1)
  rhs_block_lo, rhs_block_hi = jnp.split(rhs_block, 2, axis=0)
  out_block  = lhs_blocks(idx, 0) @ rhs_block_lo
  out_block += lhs_blocks(idx, 1) @ rhs_block_hi
  for i in range(1, size):
    rhs_block_lo = shift_up(rhs_block_lo)
    rhs_block_hi = shift_dn(rhs_block_hi)
    out_block += lhs_blocks((idx - i) % size, 0) @ rhs_block_lo
    out_block += lhs_blocks((idx + i) % size, 1) @ rhs_block_hi
  return out_block 
out = matmul_allgather_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3)) 
True 

In practice, to reduce compile times we would probably roll this loop into jax.lax.fori_loop. We might also involve additional axes of parallelism.
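
For instance, here is a minimal sketch (not from the original) of rolling the unidirectional overlapped version into jax.lax.fori_loop; the name matmul_allgather_overlapped_fori is just illustrative:

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_allgather_overlapped_fori(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift = partial(jax.lax.ppermute, axis_name='i',
                  perm=[(i, (i + 1) % size) for i in range(size)])
  B = lhs_block.shape[1] // size
  lhs_blocks = lambda i: lax.dynamic_slice_in_dim(lhs_block, i * B, B, 1)

  def body(i, carry):
    out_block, rhs_blk = carry
    rhs_blk = shift(rhs_blk)                             # communication step
    out_block += lhs_blocks((idx - i) % size) @ rhs_blk  # local matmul step
    return out_block, rhs_blk

  out_block, _ = jax.lax.fori_loop(1, size, body,
                                   (lhs_blocks(idx) @ rhs_block, rhs_block))
  return out_block

out = matmul_allgather_overlapped_fori(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))

The carry holds the running output block together with the rhs block being passed around the ring, so the loop body is traced only once regardless of how many devices are in the mesh.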

Example 2: psum_scatter the result

Another sharding we might start with has both lhs and rhs sharded along their contracting dimensions, with the output again sharded like rhs:

lhs_spec = P(None, 'i')
lhs = device_put(lhs, lhs_spec)
rhs_spec = P('i', None)
rhs = device_put(rhs, rhs_spec) 

Here, we can use a reduce-scatter to perform the contraction sum over the shards:

@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter(lhs_block, rhs_block):
  out_summand = lhs_block @ rhs_block
  return jax.lax.psum_scatter(out_summand, 'i', tiled=True)
out = matmul_psumscatter(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3)) 
True 

But the scattering communication must wait for the entire local matrix multiplication to finish before it can start. To get communication/computation overlap, we can inline an implementation of psum_scatter in terms of ppermute, then interleave the communication steps with the local matrix multiplications:

@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter_overlapped(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift = partial(jax.lax.ppermute, axis_name='i',
                  perm=[(i, (i - 1) % size) for i in range(size)])
  lhs_block = lhs_block.reshape(size, -1, lhs_block.shape[1])  # split 1st axis
  out_summand = lhs_block[(idx + 1) % size] @ rhs_block
  for i in range(1, size):
    out_summand = shift(out_summand)
    out_summand += lhs_block[(idx + i + 1) % size] @ rhs_block
  return out_summand 
out = matmul_psumscatter_overlapped(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3)) 
True 

As in the previous example, to fully utilize the interconnects on TPU, we'd run a bidirectional version:

@partial(shard_map, mesh=mesh, in_specs=(lhs_spec, rhs_spec),
         out_specs=rhs_spec)
def matmul_psumscatter_overlapped_bidi(lhs_block, rhs_block):
  size = jax.lax.psum(1, 'i')
  idx = jax.lax.axis_index('i')
  shift_up = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i + 1) % size) for i in range(size)])
  shift_dn = partial(jax.lax.ppermute, axis_name='i',
                     perm=[(i, (i - 1) % size) for i in range(size)])
  B = lhs_block.shape[0] // size // 2  # half-size blocks
  lhs_blocks = lambda i, hi: lax.dynamic_slice_in_dim(lhs_block, (2*i+hi) * B, B, 0)
  out_summand_lo = lhs_blocks((idx - 1) % size, 0) @ rhs_block
  out_summand_hi = lhs_blocks((idx + 1) % size, 1) @ rhs_block
  for i in range(1, size):
    out_summand_lo = shift_up(out_summand_lo)
    out_summand_hi = shift_dn(out_summand_hi)
    out_summand_lo += lhs_blocks((idx - i - 1) % size, 0) @ rhs_block
    out_summand_hi += lhs_blocks((idx + i + 1) % size, 1) @ rhs_block
  return jnp.concatenate([out_summand_lo, out_summand_hi]) 
out = matmul_psumscatter_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3)) 
True 

Neural networks

We can use shard_map to parallelize computation in neural networks, either by itself or in combination with the automatic partitioning in jax.jit. This section has a few examples based on the following toy neural network and random data:

import jax
import jax.numpy as jnp
def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jax.nn.relu(outputs)
  return outputs
def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1)) 
def init_layer(key, n_in, n_out):
    k1, k2 = jax.random.split(key)
    W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
    b = jax.random.normal(k2, (n_out,))
    return W, b
def init(key, layer_sizes, batch_size):
    key, *keys = jax.random.split(key, len(layer_sizes))
    params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
    key, *keys = jax.random.split(key, 3)
    inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
    targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))
    return params, (inputs, targets) 
layer_sizes = [784, 128, 128, 128, 128, 128, 8]
batch_size = 32
params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size) 

Compare these examples with the purely automatic partitioning examples in the "Distributed arrays and automatic parallelization" doc. While in those automatic partitioning examples we don't need to edit the model functions to use different parallelization strategies, with shard_map we often do.
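
For instance, as a minimal sketch of the automatic-partitioning side of that comparison (the mesh_auto, batch_auto, and params_auto names are illustrative, not from the original doc), we can shard the data, replicate the parameters, and let jax.jit partition the unmodified reference loss on its own:

from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

mesh_auto = Mesh(mesh_utils.create_device_mesh((8,)), ('batch',))
batch_auto = jax.device_put(batch, NamedSharding(mesh_auto, P('batch')))
params_auto = jax.device_put(params, NamedSharding(mesh_auto, P()))
print(jax.jit(loss)(params_auto, batch_auto))  # no edits to the model code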

8-way batch data parallelism

The simplest multi-device parallelism strategy is to shard the batch of inputs and targets over multiple devices, replicate the parameters over those devices, and apply the model in parallel to those shards of data. To evaluate the total loss, the devices need only communicate a scalar-sized all-reduce-sum at the end. (To evaluate the gradient of the loss, the devices must perform all-reduce-sums of the parameter gradients in the backward pass.)

from functools import partial
from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((8,))
# replicate initial params on all devices, shard data batch over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P()))
# adapt the loss function to sum the losses across devices
def loss_dp(params, batch):
  @partial(shard_map, mesh=mesh, in_specs=P('batch', None), out_specs=P())
  def loss_spmd(local_batch):
    inputs, targets = local_batch
    predictions = predict(params, inputs)  # use the reference predict
    local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
    return jax.lax.pmean(local_loss, 'batch')
  return loss_spmd(batch) 

We can check that the loss and its gradients match the reference (base) model:

print(jax.jit(loss)(params, batch))
print(jax.jit(loss_dp)(params, batch)) 
22.779888
22.779888 
from jax.tree_util import tree_all, tree_map

def allclose(a, b):
  return tree_all(tree_map(partial(jnp.allclose, atol=1e-2, rtol=1e-2), a, b))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_dp))(params, batch))) 
True 

We can print the compiler IR to inspect the gradient computation and verify that the collective all-reduce-sum operations happen where we'd expect: at the end of the forward pass to compute the loss value, and in the backward pass to compute the total parameter gradients.
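
For example (a small sketch, not from the original, assuming the default text output of Lowered.as_text), we can dump the lowered IR for the gradient computation and scan it for all-reduce collectives:

# Dump the compiler IR for the data-parallel gradient and look for all-reduces.
ir_text = jax.jit(jax.grad(loss_dp)).lower(params, batch).as_text()
print('\n'.join(line for line in ir_text.splitlines()
                if 'all_reduce' in line or 'all-reduce' in line))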

8-way fully sharded data parallelism (FSDP)

Another strategy is to additionally shard the parameters over the devices, all-gathering each one when its full value is needed for the jnp.dot or the bias addition. Since we only hold one full parameter in local device memory at a time, rather than keeping all parameters in every device's memory as in the preceding DP example, we free up significant memory that we can use for larger models or larger batch sizes. And because XLA overlaps computation with inter-device communication, the wall-clock time doesn't suffer.

So now we need collectives in two places: the model prediction function predict needs to all-gather the parameters before they're used, and, as in the DP case, the loss function needs to sum the local losses to compute the total loss.

There's one other ingredient we need: we don't want to store the fully gathered parameters from the forward pass for use in the backward pass. Instead, we want to gather them again on the backward pass. We can express that by using jax.remat with a custom policy (or a custom_vjp), though XLA typically performs that rematerialization automatically.

This general FSDP approach is similar to weight update sharding (WUS) and ZeRO-3.

# shard data batch *and params* over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P('batch')))
# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp(params_frag, inputs):
  for W_frag, b_frag in params_frag:
    W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
    b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
    outputs = jnp.dot(inputs, W) + b
    inputs = jax.nn.relu(outputs)
  return outputs
def loss_fsdp(params, batch):
  @partial(shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())
  def loss_spmd(local_params, local_batch):
    inputs, targets = local_batch
    predictions = predict_fsdp(local_params, inputs)
    local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
    return jax.lax.pmean(local_loss, 'batch')
  return loss_spmd(params, batch) 

Again, we can check that the loss and its gradients match the reference model:

print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp)(params, batch))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_fsdp))(params, batch))) 
22.779888
22.779888
True 

8-way tensor parallelism (TP)

Usually we don't use tensor model parallelism by itself, but seeing it in isolation is a good warmup on parallel matrix multiplication. It's also a good example of using shard_map in a library function, called from within a larger jit-based computation.

The parallelization idea is that we keep the data/activations sharded over their feature axis (rather than their batch axis), and we similarly shard the weight matrices over their input-feature axis (and the biases over their feature axis). Then, to perform the parallel matrix multiplication, we perform local matrix multiplications followed by a psum_scatter to sum the local results and efficiently scatter the result's shards.

devices = mesh_utils.create_device_mesh((8,))
mesh = Mesh(devices, ('feats',))
batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P('feats')))
def predict_tp(params, inputs):
  for W, b in params:
    outputs = gemm_tp(inputs, W, b)
    inputs = jax.nn.relu(outputs)
  return outputs
@partial(shard_map, mesh=mesh,
         in_specs=(P(None, 'feats'), P('feats', None), P('feats')),
         out_specs=P(None, 'feats'))
def gemm_tp(inputs, W, b):
  block_result = jnp.dot(inputs, W)
  return jax.lax.psum_scatter(block_result, 'feats',
                              scatter_dimension=1, tiled=True) + b
def loss_tp(params, batch):
  inputs, targets = batch
  predictions = predict_tp(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1))  # NOTE psum! 
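
As a quick sanity check (a small sketch in the style of the earlier sections, with no output shown here), we can evaluate the TP loss under jit, which inserts the psum noted above, and compare it against the reference:

print(jax.jit(loss)(params, batch))
print(jax.jit(loss_tp)(params, batch))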

FSDP + TP, with shard_map at the top level

We can compose these strategies together, using multiple axes of parallelism.

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('batch', 'feats'))
batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))
params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))
# mostly same as previous predict_fsdp definition, except we call gemm_tp
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def predict_fsdp_tp(params_frag, inputs):
  for W_frag, b_frag in params_frag:
    W = jax.lax.all_gather(W_frag, 'batch', tiled=True)
    b = jax.lax.all_gather(b_frag, 'batch', tiled=True)
    block_result = jnp.dot(inputs, W)
    outputs = jax.lax.psum_scatter(block_result, 'feats',
                                   scatter_dimension=1, tiled=True) + b
    inputs = jax.nn.relu(outputs)
  return outputs
@partial(shard_map, mesh=mesh,
         in_specs=(P(('feats', 'batch')), P('batch', 'feats')),
         out_specs=P())
def loss_fsdp_tp(local_params, local_batch):
  inputs, targets = local_batch
  predictions = predict_fsdp_tp(local_params, inputs)
  sq_err = jax.lax.psum(jnp.sum((predictions - targets)**2, axis=-1), 'feats')
  return jax.lax.pmean(jnp.mean(sq_err), 'batch') 

Notice how we have to do two collective reductions: one over 'feats' and one over 'batch'. In the pure TP example, we didn't write the 'feats' reduction explicitly because we only used shard_map inside gemm_tp; in the caller loss_tp, the compiler automatically translated our use of jnp.sum into the psum needed, given the sharded result returned by predict_tp.

print(jax.jit(loss)(params, batch))
print(jax.jit(loss_fsdp_tp)(params_, batch_))
print(allclose(jax.jit(jax.grad(loss))(params, batch),
               jax.jit(jax.grad(loss_fsdp_tp))(params, batch))) 
22.779886
22.779886
True 

SPMD pipeline parallelism (PP)

With pipeline parallelism, we aim to parallelize the evaluation of layers at different depths in the network. For example, one device might compute the application of the first layer while another device computes the application of the second; when they finish, the first device passes its results to the second, while the second passes its results to the device responsible for the third layer, and the process repeats. In general, the number of pipeline stages may differ from the number of layers, since each stage may be responsible for multiple layers.

With SPMD pipelining, we exploit the fact that most layers in the network apply the same computation, just with different parameter values. In particular, we can stack together all the parameters except those of the first and last layers, then use shard_map to map over blocks of those layer parameters, with each block corresponding to one pipeline stage. We then use the jax.lax.ppermute collective to shift data down the parallel pipeline.

This particular pipelining strategy is essentially the GPipe strategy. There are several variants, as well as quite different strategies, and which one is appropriate depends on the network speed between stages and on batch sizes. But for this tutorial we'll focus on just one strategy.

First, we pick some pipeline parameters:

L = len(params) - 2        # num layers, excluding first and last
N = batch_size             # batch size
F = params[0][0].shape[1]  # num features
# choose some pipeline parameters
S = 2      # number of stages
B = 8      # size of each microbatch
assert L % S == 0, "S (number of stages) must divide L (number of inner layers)"
# compute some useful quantities
M, ragged = divmod(N, B)  # M is number of microbatches
assert not ragged, "B (size of each microbatch) must divide total batch size"
K, ragged = divmod(M, S)  # K is microbatches per stage
assert not ragged, "S (number of stages) must divide number of microbatches"
print(f'{S} stages, {L // S} layer(s) per stage, {L} pipelined layers total')
print(f'{B} examples per microbatch, {M} microbatches total') 
2 stages, 2 layer(s) per stage, 4 pipelined layers total
8 examples per microbatch, 4 microbatches total 
mesh = Mesh(jax.devices()[:S], ('stages',))
def predict_pp(params, inputs):
  (W_first, b_first), inner_params, (W_last, b_last) = params
  inputs = jax.nn.relu(jnp.dot(inputs, W_first) + b_first)
  inputs = spmd_pipeline(lambda Wb, x: jax.nn.relu(x @ Wb[0] + Wb[1]),
                        inner_params, inputs)
  outputs = jnp.dot(inputs, W_last) + b_last
  return outputs
@partial(shard_map, mesh=mesh, in_specs=((P(), P('stages'), P()), P('stages')),
         out_specs=P())
def loss_pp(params, batch):
  inputs, targets = batch
  predictions = predict_pp(params, inputs.reshape(K, B, -1)).reshape(K * B, -1)
  local_loss = jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
  return jax.lax.pmean(local_loss, 'stages') 
def spmd_pipeline(fn, stage_params, inputs):
  stage = jax.lax.axis_index('stages')          # which pipeline stage this device is
  outputs = jnp.zeros_like(inputs) * jnp.nan    # buffer for this stage's K microbatch outputs
  state = jnp.zeros((L // S, B, F)) * jnp.nan   # one in-flight microbatch per layer on this stage
  for i in range(M+L-1):  # M microbatches, plus L-1 steps to fill and drain the pipeline
    # stage 0 feeds the next microbatch into its first layer slot
    state = state.at[0].set(jnp.where(stage == 0, inputs[i % K], state[0]))
    # each stage applies each of its layers to the microbatch currently at that layer
    state = jax.vmap(fn)(stage_params, state)
    # the last stage stores its latest result in the output buffer
    outputs = outputs.at[(i-L+1) % K].set(jnp.where(stage == S-1, state[-1], outputs[(i-L+1) % K]))
    state, inputs, outputs = shift(i, state, inputs, outputs)
  # one final shift so each stage ends up with the outputs for its own shard of the batch
  outputs = jax.lax.ppermute(outputs, 'stages', [(i, (i+1) % S) for i in range(S)])
  return outputs
def shift(i, state, inputs, outputs):
  sh = lambda x, d: jax.lax.ppermute(x, 'stages', [(i, (i+d) % S) for i in range(S)])
  # advance activations one layer, passing the last layer's output on to the next stage
  state = jnp.roll(state, +1, axis=0).at[0].set(sh(state[-1], +1))
  if (i % K) == (-1 % K):
    inputs = sh(inputs, +1)    # rotate the input buffers between stages every K steps
  if ((i-L+1) % K) == (-1 % K):
    outputs = sh(outputs, +1)  # rotate the output buffers every K microbatch writes
  return state, inputs, outputs
first_params, *inner_params, last_params = params
Ws, bs = zip(*inner_params)
params_stacked = jnp.stack(Ws), jnp.stack(bs)
first_params = jax.device_put(first_params, NamedSharding(mesh, P()))
params_stacked = jax.device_put(params_stacked, NamedSharding(mesh, P('stages')))
last_params = jax.device_put(last_params, NamedSharding(mesh, P()))
params_ = first_params, params_stacked, last_params
batch_ = jax.device_put(batch, NamedSharding(mesh, P('stages'))) 
print(jax.jit(loss)(params, batch))
print(jax.jit(loss_pp)(params_, batch_)) 
22.779886
22.779884 
_ = jax.jit(jax.grad(loss_pp))(params_, batch_)   # don't crash 


JAX Chinese Documentation (7) (4): https://developer.aliyun.com/article/1559700
