JAX Chinese Documentation (Part 9), Section 5

Previous section: JAX Chinese Documentation (Part 9), Section 4: https://developer.aliyun.com/article/1559676

Let's test it.

import jax
import jax.numpy as jnp

per_core_batch_size = 4
seq_len = 512
emb_dim = 512
x = jax.random.normal(
    jax.random.PRNGKey(0),
    shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
    dtype=jnp.bfloat16,
)
norm_shape = x.shape[-2:]
weight = jnp.ones(norm_shape, dtype=jnp.bfloat16)

Test the forward function:

out = rms_norm_fwd(x, weight) 
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In [5], line 1
----> 1 out = rms_norm_fwd(x, weight)
...
NotImplementedError: Abstract evaluation for 'rms_norm_fwd' not implemented 

Abstract evaluation

The test above fails with NotImplementedError: Abstract evaluation for 'rms_norm_fwd' not implemented. Why did it fail, and what does this error mean?

As part of execution, JAX performs abstract evaluation. Because JAX knows nothing about the new primitives, it does not know how to compute their output shapes and output data types, and therefore cannot evaluate these operations abstractly.

We need to provide an abstract evaluation function for each primitive. These abstract evaluation functions compute the shapes and data types of the outputs, but do not compute the actual values of the operations.
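For intuition, this is the same kind of shape-and-dtype-only evaluation that jax.eval_shape exposes for ordinary JAX code. The snippet below is a minimal illustration, independent of our custom primitive, using only standard JAX APIs:

import jax
import jax.numpy as jnp

# Abstractly evaluate a matmul: only shapes and dtypes flow through, no FLOPs run.
a = jax.ShapeDtypeStruct((128, 256), jnp.bfloat16)
b = jax.ShapeDtypeStruct((256, 512), jnp.bfloat16)
out = jax.eval_shape(jnp.matmul, a, b)
print(out)  # ShapeDtypeStruct(shape=(128, 512), dtype=bfloat16)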

These functions are registered with their corresponding primitives by passing them to the .def_abstract_eval method.

For more on abstract evaluation, see How JAX primitives work.

from functools import reduce
from operator import mul
from jax import dtypes
from jax.core import ShapedArray

def _rms_norm_fwd_abstract(x, weight, eps):
    w_dtype = dtypes.canonicalize_dtype(weight.dtype)
    iv_dtype = dtypes.canonicalize_dtype(x.dtype)
    if iv_dtype in [jnp.float16, jnp.bfloat16]:
        iv_dtype = jnp.float32
    n2 = reduce(mul, weight.shape)
    n1 = reduce(mul, x.shape) // n2
    return (
        ShapedArray(x.shape, w_dtype, named_shape=x.named_shape),  # output
        ShapedArray((n1,), iv_dtype, named_shape=x.named_shape),  # invvar
    )

_rms_norm_fwd_p.def_abstract_eval(_rms_norm_fwd_abstract)
def _rms_norm_bwd_abstract(grad_output, invvar, x, weight, eps):
    iv_dtype = dtypes.canonicalize_dtype(invvar.dtype)
    w_dtype = dtypes.canonicalize_dtype(weight.dtype)
    x_dtype = dtypes.canonicalize_dtype(x.dtype)
    n2 = reduce(lambda x, y: x * y, weight.shape)
    n1 = reduce(lambda x, y: x * y, x.shape) // n2
    part_grad_shape = (16, n2)
    assert dtypes.canonicalize_dtype(grad_output.dtype) == w_dtype
    assert grad_output.shape == x.shape
    assert invvar.shape == (n1,)
    assert iv_dtype == (
        jnp.float32 if x_dtype in [jnp.float16, jnp.bfloat16] else x_dtype
    )
    assert grad_output.named_shape == x.named_shape
    weight_named_shape = (
        weight.named_shape if weight.named_shape else x.named_shape
    )
    return (
        ShapedArray(
            x.shape, x_dtype, named_shape=x.named_shape
        ),  # grad input
        ShapedArray(
            weight.shape, w_dtype, named_shape=weight_named_shape
        ),  # grad weight
        ShapedArray(
            part_grad_shape, iv_dtype, named_shape=weight_named_shape
        ),  # part grad
    )

_rms_norm_bwd_p.def_abstract_eval(_rms_norm_bwd_abstract)
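As a quick sanity check of these rules (a sketch, assuming the registrations above succeeded): jax.eval_shape only needs the abstract evaluation, not the GPU kernel, so it should now report the expected output shape and dtype without touching the device.

inferred = jax.eval_shape(rms_norm_fwd, x, weight)
print(inferred.shape, inferred.dtype)  # (jax.local_device_count() * 4, 512, 512) bfloat16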

Let's run the test again.

Test the forward function:

out = rms_norm_fwd(x, weight) 

Test the backward function

Now let's test the backward operation, using jax.grad and jtu.check_grads.

def loss(x, weight):
    predictions = rms_norm_fwd(x, weight)
    return -jnp.mean(predictions**2)
loss_grad = jax.grad(loss)
out = loss_grad(x, weight)
jtu.check_grads(loss, (x, weight), modes=["rev"], order=1) 
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In [8], line 7
      3     return -jnp.mean(predictions**2)
      6 loss_grad = jax.grad(loss)
----> 7 out = loss_grad(x, weight)
...
NotImplementedError: Differentiation rule for 'rms_norm_fwd' not implemented 

Differentiation rules

The backward operation fails with NotImplementedError: Differentiation rule for 'rms_norm_fwd' not implemented. This means that, although we have defined rms_norm_fwd and rms_norm_bwd, JAX does not know the relationship between them.

We can teach JAX that rms_norm_bwd is the backward operation for rms_norm_fwd by using jax.custom_vjp and its conventions. As a first step, we need to refine the definitions of rms_norm_fwd and rms_norm_bwd.

# rms_norm_fwd was previously defined as
#
# def rms_norm_fwd(x, weight, eps=1e-05):
#     output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
#     return output
#
def rms_norm_fwd(x, weight, eps=1e-05):
    output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
    return output, (invvar, x, weight)
# rms_norm_bwd was previously defined as
#
# def rms_norm_bwd(g, invvar, x, weight, eps):
#     grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
#         g, invvar, x, weight, eps=eps
#     )
#     return grad_input, grad_weight
#
def rms_norm_bwd(eps, res, g):
    invvar, x, weight = res
    grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
        g, invvar, x, weight, eps=eps
    )
    return grad_input, grad_weight 

rms_norm_fwd now returns an extra output, (invvar, x, weight), as the residual data, and rms_norm_bwd takes eps, res, and g as parameters.

Once the relationship between rms_norm_fwd and rms_norm_bwd has been established through jax.custom_vjp, JAX ensures that the residual data from rms_norm_fwd is passed to rms_norm_bwd as res for the backward operation. For non-differentiable parameters such as eps, JAX ensures that they are passed to the backward operation before the residual data. That is why eps precedes res in the parameter list of rms_norm_bwd.
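To make this calling convention concrete, here is a minimal toy example that is independent of RMS normalization; the names scaled_square, scaled_square_fwd, and scaled_square_bwd are made up for illustration:

from functools import partial
import jax
import jax.numpy as jnp

# scale is marked non-differentiable, so the backward rule receives it first,
# followed by the residuals saved by the forward rule, then the cotangent g.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scaled_square(x, scale):
    return scale * x ** 2

def scaled_square_fwd(x, scale):
    # Return the primal output plus the residuals needed by the backward pass.
    return scale * x ** 2, (x,)

def scaled_square_bwd(scale, res, g):
    # The non-differentiable argument (scale) comes before res and g.
    (x,) = res
    return (scale * 2.0 * x * g,)   # one cotangent per differentiable argument

scaled_square.defvjp(scaled_square_fwd, scaled_square_bwd)

print(jax.grad(scaled_square)(3.0, 2.0))   # 12.0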

Since rms_norm_fwd now returns residual data that is not needed for the plain RMS normalization operation, we define a wrapper around it, rms_norm, which simply calls rms_norm_fwd and returns only output. Note that rms_norm is annotated with @partial(jax.custom_vjp, nondiff_argnums=(2,)), and that we pass rms_norm_fwd and rms_norm_bwd to rms_norm.defvjp. This teaches JAX to use rms_norm_fwd for the forward operation and rms_norm_bwd for the backward operation when differentiating rms_norm.

See Custom derivative rules for more on defining custom derivative rules for JAX-transformable Python functions with jax.custom_vjp.

@partial(jax.custom_vjp, nondiff_argnums=(2,))
def rms_norm(x, weight, eps=1e-05):
    output, _ = rms_norm_fwd(x, weight, eps=eps)
    return output
rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd) 

With these refinements, the backward operation test now passes, with one modification: loss now calls rms_norm instead of rms_norm_fwd.

def loss(x, weight):
    predictions = rms_norm(x, weight)
    return -jnp.mean(predictions**2)
loss_grad = jax.grad(loss)
out = loss_grad(x, weight)
jtu.check_grads(loss, (x, weight), modes=["rev"], order=1) 

Let's test it on multiple devices.

We use jax.experimental.pjit.pjit to run in parallel across multiple devices, and sequential execution on a single device to produce the reference values.

Test the forward function.

Let's first test the forward operation on multiple devices. We create a simple 1D mesh and shard x across all devices.

from jax.sharding import Mesh, PartitionSpec
from jax.experimental.pjit import pjit
mesh = Mesh(jax.local_devices(), ("x",))
ref = rms_norm(x, weight)
pjitted = pjit(
    rms_norm,
    # Shard x by batch dimension and replicate weight on all devices.
    in_shardings=(PartitionSpec("x", None, None), PartitionSpec(None, None)),
    # Shard the output by batch dimension.
    out_shardings=PartitionSpec("x", None, None),
)
with mesh:
    print(pjitted.lower(x, weight).compile().runtime_executable().hlo_modules()[0].to_string())
    out = pjitted(x, weight)
jnp.allclose(ref, out, atol=1e-5, rtol=1e-5) 
HloModule pjit_rms_norm, entry_computation_layout={(bf16[4,512,512]{2,1,0},bf16[512,512]{1,0})->bf16[4,512,512]{2,1,0}}
%fused_computation (param_1: bf16[32,512,512], param_1.3: u32[]) -> bf16[4,512,512] {
  %param_1 = bf16[32,512,512]{2,1,0} parameter(0)
  %param_1.3 = u32[] parameter(1)
  %convert.2 = s32[] convert(u32[] %param_1.3), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  %constant_9 = s32[] constant(4), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  %multiply.3 = s32[] multiply(s32[] %convert.2, s32[] %constant_9), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  %constant_8 = s32[] constant(0), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  ROOT %dynamic-slice.2 = bf16[4,512,512]{2,1,0} dynamic-slice(bf16[32,512,512]{2,1,0} %param_1, s32[] %multiply.3, s32[] %constant_8, s32[] %constant_8), dynamic_slice_sizes={4,512,512}, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
}
ENTRY %main.7_spmd (param: bf16[4,512,512], param.1: bf16[512,512]) -> bf16[4,512,512] {
  %param = bf16[4,512,512]{2,1,0} parameter(0), sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}
  %all-gather = bf16[32,512,512]{2,1,0} all-gather(bf16[4,512,512]{2,1,0} %param), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  %param.1 = bf16[512,512]{1,0} parameter(1), sharding={replicated}
  %custom-call.0 = (bf16[32,512,512]{2,1,0}, f32[32]{0}) custom-call(bf16[32,512,512]{2,1,0} %all-gather, bf16[512,512]{1,0} %param.1), custom_call_target="rms_forward_affine_mixed_dtype", operand_layout_constraints={bf16[32,512,512]{2,1,0}, bf16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}, backend_config=" \000\000\000\000\000\004\000\361h\343\210\265\370\344>\000\000\000\000\000\000\000\000\000\000\000\000\255\177\000\000"
  %get-tuple-element = bf16[32,512,512]{2,1,0} get-tuple-element((bf16[32,512,512]{2,1,0}, f32[32]{0}) %custom-call.0), index=0, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  %partition-id = u32[] partition-id(), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  ROOT %fusion = bf16[4,512,512]{2,1,0} fusion(bf16[32,512,512]{2,1,0} %get-tuple-element, u32[] %partition-id), kind=kLoop, calls=%fused_computation, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
} 
True 

For the forward operation, the values are computed correctly; however, the generated HLO module shows an all-gather operation that replicates x on all devices, which incurs a large communication overhead.

Because XLA does not have enough knowledge about the custom function to shard its input tensors, it decides to replicate them before the custom call in order to produce correct values.
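You do not have to read the full HLO dump to spot this. The sketch below reuses pjitted, x, weight, out, and mesh from above to check the compiled text and the result's sharding programmatically:

with mesh:
    compiled_text = pjitted.lower(x, weight).compile().as_text()
print(out.sharding)                    # output is sharded over the batch dimension as requested
print("all-gather" in compiled_text)   # True: the input is replicated internally before the custom call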

To avoid this replication, we can:

  • use custom_partitioning: it makes the op behave like all native JAX operations (but it is more complicated);
  • use manual sharding.

This example demonstrates the use of custom_partitioning.

Shard the forward function with custom_partitioning

First, let's create a helper function that performs all the JAX/XLA callback registrations we need.

def register_primitive(cls):
    """Register a JAX primitive.

    Each operation is composed of two primitives: Inner and Outer.

    Inner wraps only the custom_call itself:
    - impl: the XLA custom_call implemented in C++.
    - abstract: to know the static shapes.
    - lower: to a StableHLO XLA custom_call.

    Outer handles mostly everything else:
    - impl: binds to the inner primitive. Not used for real computation, only for tracing, so binding is all we need.
    - abstract: same as Inner.
    - lower: to a StableHLO custom_p (XLA will call the Python callbacks from it).
    - custom_p
    - vmap: could be added here.

    The VJP is based on Outer, but is not handled in this function.
    """
    def name_of_wrapper_p():
        return cls.name + "_wrapper"

    # core, dispatch, xla, mlir, batching, and custom_partitioning below come from
    # JAX modules imported elsewhere in the full example.
    inner_p = core.Primitive(cls.name)
    dispatch.prim_requires_devices_during_lowering.add(inner_p)
    inner_p.multiple_results = cls.multiple_results
    inner_p.def_impl(partial(xla.apply_primitive, inner_p))
    inner_p.def_abstract_eval(cls.abstract)
    mlir.register_lowering(inner_p, cls.lowering, platform='cuda')
    cls.inner_primitive = inner_p

    outer_p = core.Primitive(name_of_wrapper_p())
    dispatch.prim_requires_devices_during_lowering.add(outer_p)
    outer_p.multiple_results = cls.multiple_results
    outer_p.def_impl(cls.impl)
    outer_p.def_abstract_eval(cls.abstract)
    batching.primitive_batchers[outer_p] = cls.batcher
    outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
    outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands,
                                partition=cls.partition)
    mlir.register_lowering(outer_p,
                           mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results))
    cls.outer_primitive = outer_p
...

We define two JAX primitives: an inner primitive that maps to the real kernel we want to wrap in JAX, and an outer primitive that is used with the custom partitioning registration and for the gradient. (If you implement an interface that supports vmap, it will also live on the outer primitive.)

JAX's custom_partitioning implementations are callbacks from XLA to Python during XLA's sharding logic. XLA sharding happens in two phases: the sharding propagation phase and the partitioning phase. The propagation phase is when XLA plans which shardings to create; the partitioning phase creates the sharded graph. For XLA to be able to shard our custom operations, we need to define two extra functions: infer_sharding_from_operands() and partition(). They are used in the first and second phase, respectively.

The infer_sharding_from_operands() function must do what its name says: infer the output shardings from the input shardings.

The partition() function does several things:

  • it tells XLA which input shardings are expected; XLA will reshard the inputs if necessary;
  • it tells XLA the final version of the output shardings;
  • it gives a function that will create the new instructions from the sharded inputs.

See the comments in the code for more explanation:

class RmsNormFwdClass:
    name = "rms_forward_affine_mixed_dtype"
    multiple_results = True
    impl_static_args = (2,)    # eps
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def infer_sharding_from_operands(eps: float, mesh: jax.sharding.Mesh,
                                     arg_infos: Tuple[jax._src.api.ShapeDtypeStruct],
                                     result_infos: Tuple[jax._src.core.ShapedArray]):
        del eps, result_infos  # Not needed for this example.
        x_info, weight_info = arg_infos
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2
        # partition() will force all dims of all inputs to be replicated, except the
        # first dim of x, which is kept as is.
        # This is because the implementation can only be sharded on the batch dimensions.
        x_spec = arg_infos[0].sharding.spec
        # None means that we replicate on that dimension.
        output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
        return (output_sharding, invvar_sharding)

    @staticmethod
    def partition(eps: float, mesh: jax.sharding.Mesh,
                  arg_infos: Tuple[jax._src.api.ShapeDtypeStruct],
                  result_infos: Tuple[jax._src.api.ShapeDtypeStruct]):
        del result_infos  # Not needed for this example.
        x_info, weight_info = arg_infos
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2
        x_spec = arg_infos[0].sharding.spec
        # We only support sharding on the batch dimensions.
        # Force replication (None) on all other dimensions.
        arg_shardings = (NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
                         NamedSharding(mesh, PartitionSpec(None, None)))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
        output_shardings = (arg_shardings[0], invvar_sharding)
        # The sharded impl only accepts positional arguments,
        # and they should be JAX-traceable variables.
        impl = partial(RmsNormFwdClass.impl, eps=eps)
        return mesh, impl, output_shardings, arg_shardings

register_primitive(RmsNormFwdClass)

Next, we define the primitive for the backward pass of the RMS normalization.

Shard the backward function with custom_partitioning

class RmsNormBwdClass:
    name = "rms_norm_bwd"
    multiple_results = True
    impl_static_args = (4,)    # eps
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def infer_sharding_from_operands(eps: float, mesh: jax.sharding.Mesh,
                                     arg_infos: Tuple[jax._src.api.ShapeDtypeStruct],
                                     result_infos: Tuple[jax._src.core.ShapedArray]):
        del eps, result_infos  # Not needed for this example.
        g_info, invvar_info, x_info, weight_info = arg_infos
        assert len(g_info.shape) == 3
        assert len(invvar_info.shape) == 1
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2
        # partition() will force all dims to be replicated except the batch dimension.
        x_spec = x_info.sharding.spec
        output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(None, None))
        return (output_sharding, invvar_sharding, output_sharding)

    @staticmethod
    def partition(eps: float, mesh: jax.sharding.Mesh,
                  arg_infos: Tuple[jax._src.api.ShapeDtypeStruct],
                  result_infos: Tuple[jax._src.api.ShapeDtypeStruct]):
        del result_infos  # Not needed for this example.
        g_info, invvar_info, x_info, weight_info = arg_infos
        assert len(g_info.shape) == 3
        assert len(invvar_info.shape) == 1
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2
        # We only support sharding on the batch dimensions.
        # Force replication (None) on all other dimensions.
        # Also force g, x, and invvar to have the same batch sharding/replication.
        x_spec = x_info.sharding.spec
        arg_shardings = (NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
                         NamedSharding(mesh, PartitionSpec(x_spec[0],)),
                         NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
                         NamedSharding(mesh, PartitionSpec(None, None)))
        output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(None, None))
        output_shardings = (output_sharding, invvar_sharding, invvar_sharding)
        # The sharded impl only accepts positional arguments,
        # and they should be JAX-traceable variables.
        def impl(g, invvar, x, weight):
            grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
                g, invvar, x, weight, eps=eps
            )
            # We need to sum the weight gradient from all partitions.
            global_weight = grad_weight
            if x_spec[0]:
                global_weight = jax.lax.psum(grad_weight, x_spec[0])
            return grad_input, global_weight, part_grad
        return mesh, impl, output_shardings, arg_shardings

register_primitive(RmsNormBwdClass)

Plumb the forward and backward primitives together with a custom_vjp rule, as before:

@partial(jax.custom_vjp, nondiff_argnums=(2,))
def custom_p_rms_norm(x, weight, eps=1e-05):
    output, _ = custom_p_rms_norm_fwd(x, weight, eps=eps)
    return output
def custom_p_rms_norm_fwd(x, weight, eps=1e-05):
    output, invvar = RmsNormFwdClass.outer_primitive.bind(x, weight, eps=eps)
    return output, (invvar, x, weight)
def custom_p_rms_norm_bwd(eps, res, g):
    invvar, x, weight = res
    grad_input, grad_weight, part_grad = RmsNormBwdClass.outer_primitive.bind(
        g, invvar, x, weight, eps=eps)
    return grad_input, grad_weight
custom_p_rms_norm.defvjp(custom_p_rms_norm_fwd, custom_p_rms_norm_bwd) 

With this, we have fully defined our custom RMS normalization primitive with custom partitioning. To check for correctness, we define the following loss functions: ref_loss is the reference to compare against, and custom_p_loss uses our new primitive that implements custom partitioning.

def ref_loss(x, weight):
    predictions = rms_norm(x, weight)
    return -jnp.mean(predictions**2)

ref_out = jax.grad(ref_loss, argnums=(0, 1))(x, weight)

def custom_p_loss(x, weight):
    predictions = custom_p_rms_norm(x, weight)
    return -jnp.mean(predictions**2)

Check for correctness

with Mesh(jax.local_devices(), ("x",)):
    def run_and_verify(loss):
        pjitted = pjit(
            jax.grad(loss, argnums=(0, 1)),
            # Shard x by batch dimension and replicate weight on all devices.
            in_shardings=(
                PartitionSpec("x", None, None),
                PartitionSpec(None, None),
            ),
            # Shard the output by batch dimension and replicate weight grad on all devices.
            out_shardings=(
                PartitionSpec("x", None, None),
                PartitionSpec(None, None),
            ),
        )
        hlo = pjitted.lower(x, weight).compile().as_text()
        out = pjitted(x, weight)
        print(hlo)
        assert "all-reduce-done" in hlo, "The gradient will produce wrong value!"
        if "all-gather-start" in hlo:
            print("NOT OPTIMIZED, ALL_GATHER in the graph!")
        return out
    custom_p_out = run_and_verify(custom_p_loss)
for r, o in zip(ref_out, custom_p_out):
    print(jnp.allclose(r, o, atol=1e-6, rtol=1e-6)) 
HloModule pjit_custom_p_loss, is_scheduled=true, entry_computation_layout={(f16[4,512,512]{2,1,0}, f16[512,512]{1,0})->(f16[4,512,512]{2,1,0}, f16[512,512]{1,0})}, allow_spmd_sharding_propagation_to_parameters={false,false}, allow_spmd_sharding_propagation_to_output={false,false}, num_partitions=4, frontend_attributes={fingerprint_before_lhs="d7b9bc40de002332dd665ff2ab537b76"}
%fused_multiply (param_0: f16[4,512,512]) -> f16[4,512,512] {
  %param_0 = f16[4,512,512]{2,1,0} parameter(0)
  %constant_4_1 = f16[] constant(-4.7684e-07)
  %broadcast.8.1 = f16[4,512,512]{2,1,0} broadcast(f16[] %constant_4_1), dimensions={}, metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484}
  ROOT %multiply.5.1 = f16[4,512,512]{2,1,0} multiply(f16[4,512,512]{2,1,0} %param_0, f16[4,512,512]{2,1,0} %broadcast.8.1), metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484}
}
%region_0.9._custom_call_lowering_rule (Arg_0.10.0: f16[], Arg_1.11.0: f16[]) -> f16[] {
  %Arg_1.11.0 = f16[] parameter(1)
  %Arg_0.10.0 = f16[] parameter(0)
  ROOT %add.2.0 = f16[] add(f16[] %Arg_0.10.0, f16[] %Arg_1.11.0), metadata={op_name="jit(main)/add" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=433}
}
ENTRY %main.23_spmd (param.2: f16[4,512,512], param.1.0: f16[512,512]) -> (f16[4,512,512], f16[512,512]) {
  %param.1.0 = f16[512,512]{1,0} parameter(1), sharding={replicated}
  %param.2 = f16[4,512,512]{2,1,0} parameter(0), sharding={devices=[4,1,1]<=[4]}
  %custom-call.3.0 = (f16[4,512,512]{2,1,0}, f32[4]{0}) custom-call(f16[4,512,512]{2,1,0} %param.2, f16[512,512]{1,0} %param.1.0), custom_call_target="rms_forward_affine_mixed_dtype", operand_layout_constraints={f16[4,512,512]{2,1,0}, f16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}, backend_config="\004\000\000\000\000\000\004\000\361h\343\210\265\370\344>\001\000\000\000\001\000\000\000\000\000\000\000$V\000\000"
  %get-tuple-element.14 = f16[4,512,512]{2,1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f32[4]{0}) %custom-call.3.0), index=0, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}
  %loop_multiply_fusion = f16[4,512,512]{2,1,0} fusion(f16[4,512,512]{2,1,0} %get-tuple-element.14), kind=kLoop, calls=%fused_multiply, metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484}
  %get-tuple-element.1.0 = f32[4]{0} get-tuple-element((f16[4,512,512]{2,1,0}, f32[4]{0}) %custom-call.3.0), index=1, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}
  %custom-call.5.0 = (f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) custom-call(f16[4,512,512]{2,1,0} %loop_multiply_fusion, f32[4]{0} %get-tuple-element.1.0, f16[4,512,512]{2,1,0} %param.2, f16[512,512]{1,0} %param.1.0), custom_call_target="rms_backward_affine", operand_layout_constraints={f16[4,512,512]{2,1,0}, f32[4]{0}, f16[4,512,512]{2,1,0}, f16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}, backend_config="\004\000\000\000\000\000\004\000\361h\343\210\265\370\344>\001\000\000\000\001\000\000\000\020\000\000\000$V\000\000"
  %get-tuple-element.7.0 = f16[512,512]{1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) %custom-call.5.0), index=1, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}
  %all-reduce-start = f16[512,512]{1,0} all-reduce-start(f16[512,512]{1,0} %get-tuple-element.7.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%region_0.9._custom_call_lowering_rule, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false}}
  %all-reduce-done = f16[512,512]{1,0} all-reduce-done(f16[512,512]{1,0} %all-reduce-start), metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}
  %get-tuple-element.12.0 = f16[4,512,512]{2,1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) %custom-call.5.0), index=0, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}
  ROOT %tuple.1.0 = (f16[4,512,512]{2,1,0}, f16[512,512]{1,0}) tuple(f16[4,512,512]{2,1,0} %get-tuple-element.12.0, f16[512,512]{1,0} %all-reduce-done)
} 
True
True 

There is no all-gather in the HLO anymore: the sharding is respected, and the gradient is accumulated only through an all-reduce.

Let's put it all together

The complete definition of the primitives using custom_partitioning can be found in Custom_Operation_for_GPUs.py, and the corresponding C++ code defining the Python bindings can be found in the following files:

gpu_ops code listing

gpu_ops/kernel_helpers.h
gpu_ops/kernels.h
gpu_ops/pybind11_kernel_helpers.h
gpu_ops/gpu_ops.cpp
gpu_ops/rms_norm_kernels.cu