Let's test it
```py
per_core_batch_size = 4
seq_len = 512
emb_dim = 512
x = jax.random.normal(
    jax.random.PRNGKey(0),
    shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
    dtype=jnp.bfloat16,
)
norm_shape = x.shape[-2:]
weight = jnp.ones(norm_shape, dtype=jnp.bfloat16)
```
Test the forward function
```py
out = rms_norm_fwd(x, weight)
```
```
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In [5], line 1
----> 1 out = rms_norm_fwd(x, weight)
...
NotImplementedError: Abstract evaluation for 'rms_norm_fwd' not implemented
```
Abstract evaluation

The test above fails with NotImplementedError: Abstract evaluation for 'rms_norm_fwd' not implemented. Why does the test fail, and what does this mean?

As part of execution, JAX performs abstract evaluation. Since JAX has no knowledge of the new primitives, it does not know how to compute the output shapes and output dtypes, so it cannot abstractly evaluate these operations.

We need to provide a function for the abstract evaluation of each primitive. These abstract evaluation functions compute the shapes and dtypes of the outputs, but not the actual values of the operations.

These functions are passed to the .def_abstract_eval method to be registered with the corresponding primitives.

See How JAX primitives work for more information on abstract evaluation.
```py
from functools import reduce
from operator import mul

from jax.core import ShapedArray

def _rms_norm_fwd_abstract(x, weight, eps):
    w_dtype = dtypes.canonicalize_dtype(weight.dtype)
    iv_dtype = dtypes.canonicalize_dtype(x.dtype)
    if iv_dtype in [jnp.float16, jnp.bfloat16]:
        iv_dtype = jnp.float32
    n2 = reduce(mul, weight.shape)
    n1 = reduce(mul, x.shape) // n2
    return (
        ShapedArray(x.shape, w_dtype, named_shape=x.named_shape),  # output
        ShapedArray((n1,), iv_dtype, named_shape=x.named_shape),  # invvar
    )

_rms_norm_fwd_p.def_abstract_eval(_rms_norm_fwd_abstract)

def _rms_norm_bwd_abstract(grad_output, invvar, x, weight, eps):
    iv_dtype = dtypes.canonicalize_dtype(invvar.dtype)
    w_dtype = dtypes.canonicalize_dtype(weight.dtype)
    x_dtype = dtypes.canonicalize_dtype(x.dtype)
    n2 = reduce(lambda x, y: x * y, weight.shape)
    n1 = reduce(lambda x, y: x * y, x.shape) // n2
    part_grad_shape = (16, n2)

    assert dtypes.canonicalize_dtype(grad_output.dtype) == w_dtype
    assert grad_output.shape == x.shape
    assert invvar.shape == (n1,)
    assert iv_dtype == (
        jnp.float32 if x_dtype in [jnp.float16, jnp.bfloat16] else x_dtype
    )
    assert grad_output.named_shape == x.named_shape
    weight_named_shape = (
        weight.named_shape if weight.named_shape else x.named_shape
    )
    return (
        ShapedArray(
            x.shape, x_dtype, named_shape=x.named_shape
        ),  # grad input
        ShapedArray(
            weight.shape, w_dtype, named_shape=weight_named_shape
        ),  # grad weight
        ShapedArray(
            part_grad_shape, iv_dtype, named_shape=weight_named_shape
        ),  # part grad
    )

_rms_norm_bwd_p.def_abstract_eval(_rms_norm_bwd_abstract)
```
Let's test it again.

Test the forward function
```py
out = rms_norm_fwd(x, weight)
```
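The forward call now succeeds. For orientation, here is a minimal pure-JAX sketch of what RMS normalization computes over the axes covered by norm_shape. This reference function (rms_norm_ref) is our own addition, not part of the extension, and only approximates the CUDA kernel, whose accumulation precision may differ:

```py
# Hypothetical pure-JAX reference for comparison (not part of the custom extension).
# It normalizes over the trailing axes covered by `weight` (here the last two),
# accumulating in float32, similar to what the kernel does for half-precision inputs.
def rms_norm_ref(x, weight, eps=1e-05):
    axes = tuple(range(x.ndim - weight.ndim, x.ndim))  # axes to normalize over
    mean_sq = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=axes, keepdims=True)
    invvar = jax.lax.rsqrt(mean_sq + eps)               # 1 / RMS
    return (x.astype(jnp.float32) * invvar).astype(x.dtype) * weight

# The custom call and the reference should agree to roughly bf16 tolerance.
print(jnp.allclose(out, rms_norm_ref(x, weight), atol=1e-2, rtol=1e-2))
```

Whether the tolerances hold exactly depends on the kernel's numerics, so treat this purely as a sanity check.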
Test the backward function

Now let's test the backward operation using jax.grad and jtu.check_grads.
```py
def loss(x, weight):
    predictions = rms_norm_fwd(x, weight)
    return -jnp.mean(predictions**2)

loss_grad = jax.grad(loss)
out = loss_grad(x, weight)

jtu.check_grads(loss, (x, weight), modes=["rev"], order=1)
```
```
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In [8], line 7
      3     return -jnp.mean(predictions**2)
      6 loss_grad = jax.grad(loss)
----> 7 out = loss_grad(x, weight)
...
NotImplementedError: Differentiation rule for 'rms_norm_fwd' not implemented
```
Differentiation rule

The backward operation failed with the error NotImplementedError: Differentiation rule for 'rms_norm_fwd' not implemented. This means that, even though we have defined rms_norm_fwd and rms_norm_bwd, JAX does not know the relationship between them.

Using jax.custom_vjp and its conventions, we can teach JAX that rms_norm_bwd is the backward operation of rms_norm_fwd. As the first step, we refine the definitions of rms_norm_fwd and rms_norm_bwd.
```py
# rms_norm_fwd was previously defined as
#
# def rms_norm_fwd(x, weight, eps=1e-05):
#     output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
#     return output
#
def rms_norm_fwd(x, weight, eps=1e-05):
    output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
    return output, (invvar, x, weight)

# rms_norm_bwd was previously defined as
#
# def rms_norm_bwd(g, invvar, x, weight, eps):
#     grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
#         g, invvar, x, weight, eps=eps
#     )
#     return grad_input, grad_weight
#
def rms_norm_bwd(eps, res, g):
    invvar, x, weight = res
    grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
        g, invvar, x, weight, eps=eps
    )
    return grad_input, grad_weight
```
rms_norm_fwd now returns an extra output, (invvar, x, weight), as the residual data, and rms_norm_bwd takes eps, res, and g as parameters.

Once the relationship between rms_norm_fwd and rms_norm_bwd is established through jax.custom_vjp, JAX makes sure that the residual data from rms_norm_fwd is passed to rms_norm_bwd as res for the backward operation. For non-differentiable parameters such as eps, JAX makes sure that they are passed to the backward operation before the residual data. That is why eps precedes res in the parameter list of rms_norm_bwd.
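The calling convention is easiest to see on a toy function that has nothing to do with RMS norm. The following self-contained example is our own illustration (not part of the extension) of how nondiff_argnums affects the argument ordering:

```py
# Toy illustration of the jax.custom_vjp convention used here: `eps` is marked
# non-differentiable via nondiff_argnums, so the backward function receives it
# first, then the residuals, then the output cotangent.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def shrink(x, eps):
    return x / (jnp.abs(x) + eps)

def shrink_fwd(x, eps):
    out = x / (jnp.abs(x) + eps)
    return out, (x,)                    # residuals saved for the backward pass

def shrink_bwd(eps, res, g):            # eps first, then residuals, then cotangent
    (x,) = res
    # d/dx [x / (|x| + eps)] = eps / (|x| + eps)**2
    return (g * eps / (jnp.abs(x) + eps) ** 2,)   # one cotangent per differentiable input

shrink.defvjp(shrink_fwd, shrink_bwd)

jtu.check_grads(lambda x: shrink(x, 1e-3).sum(), (jnp.linspace(0.5, 2.0, 8),),
                modes=["rev"], order=1)
```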
Now that rms_norm_fwd returns residual data that is not needed for the plain RMS normalization operation, we define a wrapper around it, rms_norm, which simply calls rms_norm_fwd and returns only output. Note that rms_norm is annotated with @partial(jax.custom_vjp, nondiff_argnums=(2,)) and that we pass rms_norm_fwd and rms_norm_bwd to rms_norm.defvjp. This teaches JAX that, when rms_norm is differentiated, rms_norm_fwd is to be used for the forward operation and rms_norm_bwd for the backward operation.

See Custom derivative rules for JAX-transformable Python functions for more on defining custom derivative rules with jax.custom_vjp.
```py
@partial(jax.custom_vjp, nondiff_argnums=(2,))
def rms_norm(x, weight, eps=1e-05):
    output, _ = rms_norm_fwd(x, weight, eps=eps)
    return output

rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)
```
With our refinements, the backward operation test now works, with one modification: loss now calls rms_norm instead of rms_norm_fwd.
```py
def loss(x, weight):
    predictions = rms_norm(x, weight)
    return -jnp.mean(predictions**2)

loss_grad = jax.grad(loss)
out = loss_grad(x, weight)

jtu.check_grads(loss, (x, weight), modes=["rev"], order=1)
```
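As an additional, optional sanity check (relying on the hypothetical rms_norm_ref sketch introduced earlier, so this is not part of the original tutorial), the gradients of the custom primitive can also be compared against plain JAX autodiff of the reference:

```py
# Hypothetical extra check: differentiate the pure-JAX reference and compare
# against the gradients produced through the custom primitive.
def ref_loss_fn(x, weight):
    return -jnp.mean(rms_norm_ref(x, weight) ** 2)

ref_grads = jax.grad(ref_loss_fn, argnums=(0, 1))(x, weight)
custom_grads = jax.grad(loss, argnums=(0, 1))(x, weight)
for r, c in zip(ref_grads, custom_grads):
    print(jnp.allclose(r, c, atol=1e-2, rtol=1e-2))
```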
Let's test it on multiple devices.

We are using jax.experimental.pjit.pjit for parallel execution on multiple devices, and we produce reference values with sequential execution on a single device.

Test the forward function

Let's first test the forward operation on multiple devices. We create a simple 1D mesh and shard x across all devices.
```py
from jax.sharding import Mesh, PartitionSpec
from jax.experimental.pjit import pjit

mesh = Mesh(jax.local_devices(), ("x",))
ref = rms_norm(x, weight)
pjitted = pjit(
    rms_norm,
    # Shard x by batch dimension and replicate weight on all devices.
    in_shardings=(PartitionSpec("x", None, None), PartitionSpec(None, None)),
    # Shard the output by batch dimension.
    out_shardings=PartitionSpec("x", None, None),
)

with mesh:
    print(pjitted.lower(x, weight).compile().runtime_executable().hlo_modules()[0].to_string())
    out = pjitted(x, weight)

jnp.allclose(ref, out, atol=1e-5, rtol=1e-5)
```
HloModule pjit_rms_norm, entry_computation_layout={(bf16[4,512,512]{2,1,0},bf16[512,512]{1,0})->bf16[4,512,512]{2,1,0}} %fused_computation (param_1: bf16[32,512,512], param_1.3: u32[]) -> bf16[4,512,512] { %param_1 = bf16[32,512,512]{2,1,0} parameter(0) %param_1.3 = u32[] parameter(1) %convert.2 = s32[] convert(u32[] %param_1.3), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8} %constant_9 = s32[] constant(4), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8} %multiply.3 = s32[] multiply(s32[] %convert.2, s32[] %constant_9), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8} %constant_8 = s32[] constant(0), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8} ROOT %dynamic-slice.2 = bf16[4,512,512]{2,1,0} dynamic-slice(bf16[32,512,512]{2,1,0} %param_1, s32[] %multiply.3, s32[] %constant_8, s32[] %constant_8), dynamic_slice_sizes={4,512,512}, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8} } ENTRY %main.7_spmd (param: bf16[4,512,512], param.1: bf16[512,512]) -> bf16[4,512,512] { %param = bf16[4,512,512]{2,1,0} parameter(0), sharding={devices=[8,1,1]0,1,2,3,4,5,6,7} %all-gather = bf16[32,512,512]{2,1,0} all-gather(bf16[4,512,512]{2,1,0} %param), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8} %param.1 = bf16[512,512]{1,0} parameter(1), sharding={replicated} %custom-call.0 = (bf16[32,512,512]{2,1,0}, f32[32]{0}) custom-call(bf16[32,512,512]{2,1,0} %all-gather, bf16[512,512]{1,0} %param.1), custom_call_target="rms_forward_affine_mixed_dtype", operand_layout_constraints={bf16[32,512,512]{2,1,0}, bf16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}, backend_config=" \000\000\000\000\000\004\000\361h\343\210\265\370\344>\000\000\000\000\000\000\000\000\000\000\000\000\255\177\000\000" %get-tuple-element = bf16[32,512,512]{2,1,0} get-tuple-element((bf16[32,512,512]{2,1,0}, f32[32]{0}) %custom-call.0), index=0, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8} %partition-id = u32[] partition-id(), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8} ROOT %fusion = bf16[4,512,512]{2,1,0} fusion(bf16[32,512,512]{2,1,0} %get-tuple-element, u32[] %partition-id), kind=kLoop, calls=%fused_computation, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8} }
True
For the forward operation, the values are computed correctly; however, the generated HLO module shows an all-gather operation that replicates x on all devices, incurring a large communication overhead.

Since XLA does not have enough knowledge about the custom function to shard the input tensors, it decides to replicate them before the custom call in order to produce correct values.

To avoid this duplication, we can:

- Use custom_partitioning: to make it behave like all native JAX operations (but this is more complicated).
- Use manual sharding.

This example demonstrates the use of custom_partitioning.
Shard the forward function with custom_partitioning

First we create a helper function that performs all the required JAX/XLA callback registrations.
```py
def register_primitive(cls):
    """
    register jax primitive

    The order of calls. Each operation is composed of two primitives: Inner and Outer.

    Inner, only the basic to wrap the custom_call itself.
    - impl to XLA custom_call in C.
    - abstract to know the static shapes
    - lower to StableHLO XLA custom_call.
    Outer, mostly all the rest:
    - impl: Bind to the inner primitive. Not used for real computation, but only for tracing. So we only need to bind.
    - abstract: same
    - lower to StableHLO custom_p. (XLA will call the python callback from it)
    - custom_p
    - vmap: could be added here. VJP is based on Outer, but not handled in this function.
    """
    def name_of_wrapper_p():
        return cls.name + "_wrapper"

    inner_p = core.Primitive(cls.name)
    dispatch.prim_requires_devices_during_lowering.add(inner_p)
    inner_p.multiple_results = cls.multiple_results
    inner_p.def_impl(partial(xla.apply_primitive, inner_p))
    inner_p.def_abstract_eval(cls.abstract)
    mlir.register_lowering(inner_p, cls.lowering, platform='cuda')
    cls.inner_primitive = inner_p

    outer_p = core.Primitive(name_of_wrapper_p())
    dispatch.prim_requires_devices_during_lowering.add(outer_p)
    outer_p.multiple_results = cls.multiple_results
    outer_p.def_impl(cls.impl)
    outer_p.def_abstract_eval(cls.abstract)
    batching.primitive_batchers[outer_p] = cls.batcher
    outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
    outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands,
                                partition=cls.partition)
    mlir.register_lowering(outer_p,
                           mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results))
    cls.outer_primitive = outer_p
...
```
We define two JAX primitives: an inner primitive that maps to the real kernel we want to wrap in JAX, and an outer primitive that is used together with the custom partitioning registration and for the gradient. (If you were to implement an interface that supports vmap, it would also live on the outer primitive.)

The JAX custom_partitioning implementation consists of callbacks from XLA to Python that are invoked during XLA's sharding logic. XLA sharding happens in two phases: the sharding propagation phase and the partitioning phase. The propagation phase is when XLA plans which shardings to create. The partitioning phase creates the sharded graph. For XLA to be able to shard our custom operation, it needs us to define two extra functions: infer_sharding_from_operands() and partition(). They are used in the first and second phase, respectively.

The infer_sharding_from_operands() function must do what its name says: infer the output shardings from the input shardings.

The partition() function does a few things (a minimal sketch follows this list):

- Declare which input shardings are expected. XLA will reshard the inputs if necessary.
- Declare the final version of the output shardings.
- Return a function that builds the new instructions from the sharded inputs.
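Before looking at the full RMS norm classes, here is a deliberately tiny sketch of these two callbacks for a plain elementwise op. This toy example (toy_square and its callbacks) is our own addition, written against the same callback signatures used in the classes below; the exact custom_partitioning API may vary between JAX versions:

```py
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import NamedSharding, PartitionSpec

@custom_partitioning
def toy_square(x):
    return x * x

def toy_infer_sharding_from_operands(mesh, arg_infos, result_infos):
    # An elementwise op can simply propagate the input sharding to the output.
    return NamedSharding(mesh, arg_infos[0].sharding.spec)

def toy_partition(mesh, arg_infos, result_infos):
    # Accept the input sharding as-is, produce the output with the same sharding,
    # and return a function that runs on the per-device shard.
    arg_sharding = NamedSharding(mesh, arg_infos[0].sharding.spec)
    def sharded_impl(x_shard):
        return x_shard * x_shard
    return mesh, sharded_impl, arg_sharding, (arg_sharding,)

toy_square.def_partition(
    infer_sharding_from_operands=toy_infer_sharding_from_operands,
    partition=toy_partition,
)
```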
See the code comments for more explanation:
```py
class RmsNormFwdClass:
    name = "rms_forward_affine_mixed_dtype"
    multiple_results = True
    impl_static_args = (2,)    # eps
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def infer_sharding_from_operands(eps: float, mesh: jax.sharding.Mesh,
                                     arg_infos: Tuple[jax._src.api.ShapeDtypeStruct],
                                     result_infos: Tuple[jax._src.core.ShapedArray]):
        del eps, result_infos  # Not needed for this example.
        x_info, weight_info = arg_infos
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2
        # partition() will force all dims of all inputs to be replicated except the
        # first dim of x that will be kept as is.
        # This is because the implementation can only be sharded on the batch dimensions.
        x_spec = arg_infos[0].sharding.spec
        # None means that we replicate on that dimension.
        output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
        return (output_sharding, invvar_sharding)

    @staticmethod
    def partition(eps: float, mesh: jax.sharding.Mesh,
                  arg_infos: Tuple[jax._src.api.ShapeDtypeStruct],
                  result_infos: Tuple[jax._src.api.ShapeDtypeStruct]):
        del result_infos  # Not needed for this example.
        x_info, weight_info = arg_infos
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2
        x_spec = arg_infos[0].sharding.spec
        # We only support sharding on the batch dimensions.
        # Force sharding on all other dimensions with None.
        arg_shardings = (NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
                         NamedSharding(mesh, PartitionSpec(None, None)))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
        output_shardings = (arg_shardings[0], invvar_sharding)
        # Sharded_impl only accepts positional arguments,
        # and they should be JAX-traceable variables.
        impl = partial(RmsNormFwdClass.impl, eps=eps)

        return mesh, impl, output_shardings, arg_shardings

register_primitive(RmsNormFwdClass)
```
Next we define the primitive for the backward pass of RMSNorm.

Shard the backward function with custom_partitioning
```py
class RmsNormBwdClass:
    name = "rms_norm_bwd"
    multiple_results = True
    impl_static_args = (4,)    # eps
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def infer_sharding_from_operands(eps: float, mesh: jax.sharding.Mesh,
                                     arg_infos: Tuple[jax._src.api.ShapeDtypeStruct],
                                     result_infos: Tuple[jax._src.core.ShapedArray]):
        del eps, result_infos  # Not needed for this example.
        g_info, invvar_info, x_info, weight_info = arg_infos
        assert len(g_info.shape) == 3
        assert len(invvar_info.shape) == 1
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2
        # partition() will force all dims to be replicated except the batch dimension.
        x_spec = x_info.sharding.spec
        output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(None, None))
        return (output_sharding, invvar_sharding, output_sharding)

    @staticmethod
    def partition(eps: float, mesh: jax.sharding.Mesh,
                  arg_infos: Tuple[jax._src.api.ShapeDtypeStruct],
                  result_infos: Tuple[jax._src.api.ShapeDtypeStruct]):
        del result_infos  # Not needed for this example.
        g_info, invvar_info, x_info, weight_info = arg_infos
        assert len(g_info.shape) == 3
        assert len(invvar_info.shape) == 1
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2

        # We only support sharding on the batch dimensions.
        # Force sharding on all other dimensions with None.
        # Also force gx, x and invvar to have the same batch sharding/replication.
        x_spec = x_info.sharding.spec
        arg_shardings = (NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
                         NamedSharding(mesh, PartitionSpec(x_spec[0],)),
                         NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
                         NamedSharding(mesh, PartitionSpec(None, None)))

        output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(None, None))
        output_shardings = (output_sharding, invvar_sharding, invvar_sharding)

        # Sharded_impl only accepts positional arguments,
        # and they should be JAX-traceable variables.
        def impl(g, invvar, x, weight):
            grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
                g, invvar, x, weight, eps=eps
            )
            # We need to sum the weight gradient from all partitions.
            global_weight = grad_weight
            if x_spec[0]:
                global_weight = jax.lax.psum(grad_weight, x_spec[0])
            return grad_input, global_weight, part_grad

        return mesh, impl, output_shardings, arg_shardings

register_primitive(RmsNormBwdClass)
```
The forward and backward primitives are then plumbed together with a custom_vjp rule, as before:
```py
@partial(jax.custom_vjp, nondiff_argnums=(2,))
def custom_p_rms_norm(x, weight, eps=1e-05):
    output, _ = custom_p_rms_norm_fwd(x, weight, eps=eps)
    return output

def custom_p_rms_norm_fwd(x, weight, eps=1e-05):
    output, invvar = RmsNormFwdClass.outer_primitive.bind(x, weight, eps=eps)
    return output, (invvar, x, weight)

def custom_p_rms_norm_bwd(eps, res, g):
    invvar, x, weight = res
    grad_input, grad_weight, part_grad = RmsNormBwdClass.outer_primitive.bind(
        g, invvar, x, weight, eps=eps)
    return grad_input, grad_weight

custom_p_rms_norm.defvjp(custom_p_rms_norm_fwd, custom_p_rms_norm_bwd)
```
With that, we have fully defined our custom RMS norm primitive with custom partitioning. To check for correctness, we define the following loss functions: ref_loss is the reference to compare against, and custom_p_loss uses our new primitive that implements custom_partitioning.
```py
def ref_loss(x, weight):
    predictions = rms_norm(x, weight)
    return -jnp.mean(predictions**2)

ref_out = jax.grad(ref_loss, argnums=(0, 1))(x, weight)

def custom_p_loss(x, weight):
    predictions = custom_p_rms_norm(x, weight)
    return -jnp.mean(predictions**2)
```
Check for correctness
```py
with Mesh(jax.local_devices(), ("x",)):

    def run_and_verify(loss):
        pjitted = pjit(
            jax.grad(loss, argnums=(0, 1)),
            # Shard x by batch dimension and replicate weight on all devices.
            in_shardings=(
                PartitionSpec("x", None, None),
                PartitionSpec(None, None),
            ),
            # Shard the output by batch dimension and replicate weight grad on all devices.
            out_shardings=(
                PartitionSpec("x", None, None),
                PartitionSpec(None, None),
            ),
        )
        hlo = pjitted.lower(x, weight).compile().as_text()
        out = pjitted(x, weight)
        print(hlo)
        assert "all-reduce-done" in hlo, "The gradient will produce wrong value!"
        if "all-gather-start" in hlo:
            print("NOT OPTIMIZED, ALL_GATHER in the graph!")
        return out

    custom_p_out = run_and_verify(custom_p_loss)

for r, o in zip(ref_out, custom_p_out):
    print(jnp.allclose(r, o, atol=1e-6, rtol=1e-6))
```
HloModule pjit_custom_p_loss, is_scheduled=true, entry_computation_layout={(f16[4,512,512]{2,1,0}, f16[512,512]{1,0})->(f16[4,512,512]{2,1,0}, f16[512,512]{1,0})}, allow_spmd_sharding_propagation_to_parameters={false,false}, allow_spmd_sharding_propagation_to_output={false,false}, num_partitions=4, frontend_attributes={fingerprint_before_lhs="d7b9bc40de002332dd665ff2ab537b76"} %fused_multiply (param_0: f16[4,512,512]) -> f16[4,512,512] { %param_0 = f16[4,512,512]{2,1,0} parameter(0) %constant_4_1 = f16[] constant(-4.7684e-07) %broadcast.8.1 = f16[4,512,512]{2,1,0} broadcast(f16[] %constant_4_1), dimensions={}, metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484} ROOT %multiply.5.1 = f16[4,512,512]{2,1,0} multiply(f16[4,512,512]{2,1,0} %param_0, f16[4,512,512]{2,1,0} %broadcast.8.1), metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484} } %region_0.9._custom_call_lowering_rule (Arg_0.10.0: f16[], Arg_1.11.0: f16[]) -> f16[] { %Arg_1.11.0 = f16[] parameter(1) %Arg_0.10.0 = f16[] parameter(0) ROOT %add.2.0 = f16[] add(f16[] %Arg_0.10.0, f16[] %Arg_1.11.0), metadata={op_name="jit(main)/add" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=433} } ENTRY %main.23_spmd (param.2: f16[4,512,512], param.1.0: f16[512,512]) -> (f16[4,512,512], f16[512,512]) { %param.1.0 = f16[512,512]{1,0} parameter(1), sharding={replicated} %param.2 = f16[4,512,512]{2,1,0} parameter(0), sharding={devices=[4,1,1]<=[4]} %custom-call.3.0 = (f16[4,512,512]{2,1,0}, f32[4]{0}) custom-call(f16[4,512,512]{2,1,0} %param.2, f16[512,512]{1,0} %param.1.0), custom_call_target="rms_forward_affine_mixed_dtype", operand_layout_constraints={f16[4,512,512]{2,1,0}, f16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}, backend_config="\004\000\000\000\000\000\004\000\361h\343\210\265\370\344>\001\000\000\000\001\000\000\000\000\000\000\000$V\000\000" %get-tuple-element.14 = f16[4,512,512]{2,1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f32[4]{0}) %custom-call.3.0), index=0, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440} %loop_multiply_fusion = f16[4,512,512]{2,1,0} fusion(f16[4,512,512]{2,1,0} %get-tuple-element.14), kind=kLoop, calls=%fused_multiply, metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484} %get-tuple-element.1.0 = f32[4]{0} get-tuple-element((f16[4,512,512]{2,1,0}, f32[4]{0}) %custom-call.3.0), index=1, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None 
infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440} %custom-call.5.0 = (f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) custom-call(f16[4,512,512]{2,1,0} %loop_multiply_fusion, f32[4]{0} %get-tuple-element.1.0, f16[4,512,512]{2,1,0} %param.2, f16[512,512]{1,0} %param.1.0), custom_call_target="rms_backward_affine", operand_layout_constraints={f16[4,512,512]{2,1,0}, f32[4]{0}, f16[4,512,512]{2,1,0}, f16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}, backend_config="\004\000\000\000\000\000\004\000\361h\343\210\265\370\344>\001\000\000\000\001\000\000\000\020\000\000\000$V\000\000" %get-tuple-element.7.0 = f16[512,512]{1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) %custom-call.5.0), index=1, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483} %all-reduce-start = f16[512,512]{1,0} all-reduce-start(f16[512,512]{1,0} %get-tuple-element.7.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%region_0.9._custom_call_lowering_rule, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false}} %all-reduce-done = f16[512,512]{1,0} all-reduce-done(f16[512,512]{1,0} %all-reduce-start), metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483} %get-tuple-element.12.0 = f16[4,512,512]{2,1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) %custom-call.5.0), index=0, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None 
infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483} ROOT %tuple.1.0 = (f16[4,512,512]{2,1,0}, f16[512,512]{1,0}) tuple(f16[4,512,512]{2,1,0} %get-tuple-element.12.0, f16[512,512]{1,0} %all-reduce-done) }
True True
There is now no all-gather in the HLO; the sharding is respected, and the gradients are accumulated only with an all-reduce.

Let's put it together

The complete definition of the primitives using custom_partitioning can be found in Custom_Operation_for_GPUs.py, and the corresponding C++ code defining the python bindings can be found in the following files:

gpu_ops code listing
- gpu_ops/kernel_helpers.h
- gpu_ops/kernels.h
- gpu_ops/pybind11_kernel_helpers.h
- gpu_ops/gpu_ops.cpp
- gpu_ops/rms_norm_kernels.cu