JAX 中文文档(九)(3)https://developer.aliyun.com/article/1559675
在 JAX 中编写自定义 Jaxpr 解释器
原文:
jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html
[外链图片转存中…(img-VdY1f7HH-1718950445125)]
JAX 提供了几个可组合的函数转换(jit
,grad
,vmap
等),可以编写简洁且加速的代码。
这里我们展示了如何通过编写自定义 Jaxpr 解释器来向系统添加自己的函数转换。而且我们将自动获得与所有其他转换的可组合性。
此示例使用了内部 JAX API,可能随时会中断。任何不在API 文档中的内容都应视为内部内容。
import numpy as np import jax import jax.numpy as jnp from jax import jit, grad, vmap from jax import random
JAX 在做什么?
JAX 为数值计算提供了类似 NumPy 的 API,可以直接使用,但 JAX 真正的强大之处在于可组合的函数转换。例如jit
函数转换接受一个函数并返回一个语义上相同的函数,但由 XLA 进行惰性编译以加速器。
x = random.normal(random.key(0), (5000, 5000)) def f(w, b, x): return jnp.tanh(jnp.dot(x, w) + b) fast_f = jit(f)
当我们调用fast_f
时,会发生什么?JAX 会追踪函数并构建一个 XLA 计算图。然后将图进行即时编译(JIT)并执行。其他转换类似,它们首先会追踪函数并以某种方式处理输出追踪。要了解更多关于 JAX 追踪机制的信息,您可以参考 README 中的“How it works”部分。
Jaxpr 追踪器
Jax 中一个特别重要的追踪器是 Jaxpr 追踪器,它将操作记录到一个 Jaxpr(Jax 表达式)中。Jaxpr 是一种数据结构,可以像小型函数式编程语言一样进行评估,因此 Jaxprs 是函数转换的有用中间表示。
要首次查看 Jaxprs,可以考虑make_jaxpr
转换。make_jaxpr
本质上是一个“漂亮打印”转换:它将一个函数转换为一个函数,给定示例参数,生成其计算的 Jaxpr 表示。make_jaxpr
对于调试和内省非常有用。让我们使用它来查看一些示例 Jaxprs 的结构。
def examine_jaxpr(closed_jaxpr): jaxpr = closed_jaxpr.jaxpr print("invars:", jaxpr.invars) print("outvars:", jaxpr.outvars) print("constvars:", jaxpr.constvars) for eqn in jaxpr.eqns: print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params) print() print("jaxpr:", jaxpr) def foo(x): return x + 1 print("foo") print("=====") examine_jaxpr(jax.make_jaxpr(foo)(5)) print() def bar(w, b, x): return jnp.dot(w, x) + b + jnp.ones(5), x print("bar") print("=====") examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))
foo ===== invars: [Var(id=140117887103104):int32[]] outvars: [Var(id=140117887103296):int32[]] constvars: [] equation: [Var(id=140117887103104):int32[], 1] add [Var(id=140117887103296):int32[]] {} jaxpr: { lambda ; a:i32[]. let b:i32[] = add a 1 in (b,) } bar ===== invars: [Var(id=140117843771968):float32[5,10], Var(id=140117843772032):float32[5], Var(id=140117843772096):float32[10]] outvars: [Var(id=140117843772352):float32[5], Var(id=140117843772096):float32[10]] constvars: [] equation: [Var(id=140117843771968):float32[5,10], Var(id=140117843772096):float32[10]] dot_general [Var(id=140117843772160):float32[5]] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': dtype('float32')} equation: [Var(id=140117843772160):float32[5], Var(id=140117843772032):float32[5]] add [Var(id=140117843772224):float32[5]] {} equation: [1.0] broadcast_in_dim [Var(id=140117843772288):float32[5]] {'shape': (5,), 'broadcast_dimensions': ()} equation: [Var(id=140117843772224):float32[5], Var(id=140117843772288):float32[5]] add [Var(id=140117843772352):float32[5]] {} jaxpr: { lambda ; a:f32[5,10] b:f32[5] c:f32[10]. let d:f32[5] = dot_general[ dimension_numbers=(([1], [0]), ([], [])) preferred_element_type=float32 ] a c e:f32[5] = add d b f:f32[5] = broadcast_in_dim[broadcast_dimensions=() shape=(5,)] 1.0 g:f32[5] = add e f in (g, c) }
jaxpr.invars
- Jaxpr 的invars
是一个输入变量列表,类似于 Python 函数的参数。jaxpr.outvars
- Jaxpr 的outvars
是由 Jaxpr 返回的变量。每个 Jaxpr 都有多个输出。jaxpr.constvars
-constvars
是一个变量列表,它们也是 Jaxpr 的输入之一,但对应于跟踪中的常量(我们稍后会更详细地讨论这些内容)。jaxpr.eqns
- 一个方程列表,实质上是 let 绑定。每个方程包含输入变量列表、输出变量列表和一个原语,用于评估输入以生成输出。每个方程还有一个params
,即参数字典。
总的来说,一个 Jaxpr 封装了一个简单的程序,可以使用输入进行评估以生成输出。稍后我们将详细介绍如何做到这一点。现在需要注意的重要事项是,Jaxpr 是一个可以按我们想要的方式操作和评估的数据结构。
Jaxprs 有什么用处?
Jaxprs 是简单的程序表示,易于转换。由于 Jax 允许我们从 Python 函数中分离出 Jaxprs,它为我们提供了一种转换用 Python 编写的数值程序的方法。
您的第一个解释器:invert
让我们尝试实现一个简单的函数“inverter”,它接收原始函数的输出,并返回产生这些输出的输入。现在,让我们专注于由其他可逆的一元函数组成的简单一元函数。
目标:
def f(x): return jnp.exp(jnp.tanh(x)) f_inv = inverse(f) assert jnp.allclose(f_inv(f(1.0)), 1.0)
我们将通过 (1) 将 f
追踪到 Jaxpr 中,然后 (2) 反向解释 Jaxpr 的方式来实现这一点。在反向解释 Jaxpr 过程中,对于每个方程,我们将在表中查找原语的逆,并应用它。
1. 追踪一个函数
让我们使用 make_jaxpr
来追踪一个函数到 Jaxpr 中。
# Importing Jax functions useful for tracing/interpreting. import numpy as np from functools import wraps from jax import core from jax import lax from jax._src.util import safe_map
jax.make_jaxpr
返回一个封闭的 Jaxpr,即一个已经与跟踪中的常量(literals
)捆绑在一起的 Jaxpr。
def f(x): return jnp.exp(jnp.tanh(x)) closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5)) print(closed_jaxpr.jaxpr) print(closed_jaxpr.literals)
{ lambda ; a:f32[5]. let b:f32[5] = tanh a; c:f32[5] = exp b in (c,) } []
2. 评估 Jaxpr
在编写自定义 Jaxpr 解释器之前,让我们首先实现“默认”解释器 eval_jaxpr
,它按原样评估 Jaxpr,计算与未转换的原始 Python 函数相同的值。
为此,我们首先创建一个环境来存储每个变量的值,并在评估 Jaxpr 中的每个方程时更新该环境。
def eval_jaxpr(jaxpr, consts, *args): # Mapping from variable -> value env = {} def read(var): # Literals are values baked into the Jaxpr if type(var) is core.Literal: return var.val return env[var] def write(var, val): env[var] = val # Bind args and consts to environment safe_map(write, jaxpr.invars, args) safe_map(write, jaxpr.constvars, consts) # Loop through equations and evaluate primitives using `bind` for eqn in jaxpr.eqns: # Read inputs to equation from environment invals = safe_map(read, eqn.invars) # `bind` is how a primitive is called outvals = eqn.primitive.bind(*invals, **eqn.params) # Primitives may return multiple outputs or not if not eqn.primitive.multiple_results: outvals = [outvals] # Write the results of the primitive into the environment safe_map(write, eqn.outvars, outvals) # Read the final result of the Jaxpr from the environment return safe_map(read, jaxpr.outvars)
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5)) eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
[Array([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)]
注意,即使原始函数不返回平坦列表,eval_jaxpr
也将始终返回一个平坦列表。
此外,这个解释器不处理高阶原语(如 jit
和 pmap
),这些内容不在本指南讨论范围内。您可以参考 core.eval_jaxpr
(链接) 来查看此解释器不涵盖的边界情况。
自定义inverse
Jaxpr 解释器
inverse
解释器看起来与 eval_jaxpr
并无太大不同。我们首先设置注册表,将原语映射到它们的逆。然后编写一个自定义解释器,在注册表中查找原语。
结果表明,这个解释器看起来也类似于反向模式自动微分中使用的“转置”解释器,可以在此处找到:链接。
inverse_registry = {}
现在我们将为一些原语注册它们的逆。按照惯例,Jax 中的原语以 _p
结尾,而其中许多流行的原语位于 lax
中。
inverse_registry[lax.exp_p] = jnp.log inverse_registry[lax.tanh_p] = jnp.arctanh
inverse
将首先跟踪函数,然后自定义解释 Jaxpr。让我们建立一个简单的框架。
def inverse(fun): @wraps(fun) def wrapped(*args, **kwargs): # Since we assume unary functions, we won't worry about flattening and # unflattening arguments. closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs) out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args) return out[0] return wrapped
现在我们只需要定义 inverse_jaxpr
,它将反向遍历 Jaxpr 并在可能时反转原语。
def inverse_jaxpr(jaxpr, consts, *args): env = {} def read(var): if type(var) is core.Literal: return var.val return env[var] def write(var, val): env[var] = val # Args now correspond to Jaxpr outvars safe_map(write, jaxpr.outvars, args) safe_map(write, jaxpr.constvars, consts) # Looping backward for eqn in jaxpr.eqns[::-1]: # outvars are now invars invals = safe_map(read, eqn.outvars) if eqn.primitive not in inverse_registry: raise NotImplementedError( f"{eqn.primitive} does not have registered inverse.") # Assuming a unary function outval = inverse_registryeqn.primitive safe_map(write, eqn.invars, [outval]) return safe_map(read, jaxpr.invars)
就是这样!
def f(x): return jnp.exp(jnp.tanh(x)) f_inv = inverse(f) assert jnp.allclose(f_inv(f(1.0)), 1.0)
重要的是,你可以通过 Jaxpr 解释器进行跟踪。
jax.make_jaxpr(inverse(f))(f(1.))
{ lambda ; a:f32[]. let b:f32[] = log a; c:f32[] = atanh b in (c,) }
这就是向系统添加新转换所需的全部内容,而且你可以免费获得所有其他转换的组合!例如,我们可以在 inverse
中使用 jit
、vmap
和 grad
!
jit(vmap(grad(inverse(f))))((jnp.arange(5) + 1.) / 5.)
Array([-3.1440797, 15.584931 , 2.2551253, 1.3155028, 1\. ], dtype=float32, weak_type=True)
读者的练习
- 处理具有多个参数的原语,其中输入部分已知,例如
lax.add_p
,lax.mul_p
。 - 处理
xla_call
和xla_pmap
原语,这些原语不会与eval_jaxpr
和inverse_jaxpr
一样正常工作。
使用 C++ 和 CUDA 进行 GPU 自定义操作
原文:
jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html
JAX 预装有大量内置操作,但用户偶尔会遇到需要新操作但 JAX 不支持的情况。
为了适应这些情况,JAX 允许用户定义自定义操作,本教程旨在解释如何为 GPU 定义并在单 GPU 和多 GPU 环境中使用它们。
本教程包含来自 使用自定义 C++ 和 CUDA 代码扩展 JAX 的信息,并假设您熟悉 JAX 原语。
RMS 标准化
本教程将 RMS 标准化作为 JAX 中的自定义操作添加。请注意,可以直接使用 jax.numpy
表达 RMS 标准化。但是,我们使用它作为示例来展示如何为 GPU 创建自定义操作的过程。此操作在 gpu_ops/rms_norm_kernels.cu
中的 CUDA 代码已从 Apex 借用,并进行了修改,以消除对 PyTorch 的任何依赖。
高级步骤
本教程展示了如何编写自定义操作及其梯度。
在 C 中:每个新的 JAX 原语都需要按照以下步骤进行操作。
- 具有 CUDA 核心(核心)。
- 创建分派 CUDA 核心的 C 函数,该函数将由 XLA 调用。
- 创建描述符以传达计算所需的信息。
- 类型、形状和其他属性。
- 将 C 函数绑定到 Python
- 以创建描述符并在执行期间调用原语。
在 Python 中:您需要按照以下步骤进行操作。
- 定义新的 JAX 原语(指令/操作)
- 编写 Python 函数以使用原语构建图节点。
- 定义其抽象评估。
- 定义其降低到 MLIR。
- [可选] 定义梯度。
- [可选] 使用 custom_partitioning 或 shard_map 函数实现快速多 GPU 支持。
C 代码
参见 gpu_ops
代码列表,其中包含完整的 C++ 和 CUDA 文件代码列表。gpu_ops/rms_norm_kernels.cu
定义了以下函数,这些函数使用给定的 buffers
在指定的 stream
上启动 RMS 标准化核心。
namespace gpu_ops { void rms_forward_affine_mixed_dtypes(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len); void rms_backward_affine(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len); } // namespace gpu_ops
stream
是用于在 GPU 上执行任何核心的 CUDA 流。buffers
包含所有指向输入缓冲区的指针,后跟所有指向输出缓冲区的指针。opaque
是传递给自定义函数的任何额外信息的缓冲区,而opaque_len
是opaque
的长度。
在本教程中,我们将通过opaque
将一个RMSNormDescriptor
对象传递给这些函数。
namespace gpu_ops { enum ElementType { BF16, F16, F32, F64 }; struct RMSNormDescriptor { int n1; int n2; double eps; ElementType x_type; ElementType w_type; int part_grad_size; }; } // namespace gpu_ops
现在,我们需要通过pybind11
将这些函数以及ElementType
和RMSNormDescriptor
作为 Python 模块gpu_ops
公开。
pybind11::dict RMSNormRegistrations() { pybind11::dict dict; dict["rms_forward_affine_mixed_dtype"] = gpu_ops::EncapsulateFunction(gpu_ops::rms_forward_affine_mixed_dtypes); dict["rms_backward_affine"] = gpu_ops::EncapsulateFunction(gpu_ops::rms_backward_affine); return dict; } PYBIND11_MODULE(gpu_ops, m) { m.def("get_rms_norm_registrations", &RMSNormRegistrations); m.def("create_rms_norm_descriptor", [](int n1, int n2, double eps, gpu_ops::ElementType x_type, gpu_ops::ElementType w_type, int part_grad_size) { return gpu_ops::PackDescriptor(gpu_ops::RMSNormDescriptor{ n1, n2, eps, x_type, w_type, part_grad_size}); }); pybind11::enum_<gpu_ops::ElementType>(m, "ElementType") .value("BF16", gpu_ops::ElementType::BF16) .value("F16", gpu_ops::ElementType::F16) .value("F32", gpu_ops::ElementType::F32) .value("F64", gpu_ops::ElementType::F64); }
构建gpu_ops
扩展模块
我们使用上述代码构建了gpu_ops
Python 扩展模块。(请参阅 C++和 CUDA 文件的完整代码清单,查看gpu_ops
代码列表。)
python -m pip install pybind11==2.10.1 mkdir -p build pybind_include_path=$(python -c "import pybind11; print(pybind11.get_include())") python_executable=$(python -c 'import sys; print(sys.executable)') nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3 --generate-code=arch=compute_70,code=[compute_70,sm_70] --generate-code=arch=compute_75,code=[compute_75,sm_75] --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86] -Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c gpu_ops/rms_norm_kernels.cu -o build/rms_norm_kernels.cu.o c++ -I/usr/local/cuda/include -I$pybind_include_path $(${python_executable}-config --cflags) -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o build/gpu_ops.cpp.o -c gpu_ops/gpu_ops.cpp c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o build/gpu_ops$(${python_executable}-config --extension-suffix) build/gpu_ops.cpp.o build/rms_norm_kernels.cu.o -L/usr/local/cuda/lib64 -lcudadevrt -lcudart_static -lrt -lpthread -ldl strip build/gpu_ops$(${python_executable}-config --extension-suffix)
将 RMS 归一化添加到 JAX 作为自定义调用
gpu_ops
只是一个 Python 扩展模块,我们需要更多工作来将其插入到 JAX 中。
创建原语
我们首先创建了原语_rms_norm_fwd_p
和_rms_norm_bwd_p
,这些原语可以映射到自定义函数。我们为这些操作设置了multiple_results
属性为True
,表示该操作作为元组产生多个输出。当设置为False
时,该操作将产生单个输出而不是元组。有关更多详细信息,请参见How JAX primitives work。
from functools import partial import jax import jax.numpy as jnp import jax._src.test_util as jtu from build import gpu_ops from jax import core, dtypes from jax.interpreters import xla from jax.lib import xla_client # Create _rms_norm_fwd_p for forward operation. _rms_norm_fwd_p = core.Primitive("rms_norm_fwd") _rms_norm_fwd_p.multiple_results = True _rms_norm_fwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_fwd_p)) def rms_norm_fwd(x, weight, eps=1e-05): output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps) return output # Create _rms_norm_bwd_p for backward operation. _rms_norm_bwd_p = core.Primitive("rms_norm_bwd") _rms_norm_bwd_p.multiple_results = True _rms_norm_bwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_bwd_p)) def rms_norm_bwd(g, invvar, x, weight, eps): grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind( g, invvar, x, weight, eps=eps ) return grad_input, grad_weight
降低到 MLIR 自定义调用
为了将自定义函数映射到新原语_rms_norm_fwd_p
和_rms_norm_bwd_p
,我们需要:
- 使用
xla_client.register_custom_call_target
注册自定义函数作为自定义调用目标,并且 - 注册将原语降低为 MLIR 自定义调用的降低函数,并使用注册的自定义调用目标。
下面的函数_rms_norm_fwd_cuda_lowering
和_rms_norm_bwd_cuda_lowering
通过gpu_ops
中的自定义目标将原语降低为 MLIR 自定义调用操作。这些函数已经注册到jax.interpreters.mlir.register_lowering
中。
注意,在降低函数中创建了一个RMSNormDescriptor
对象,并将其作为opaque
传递给自定义调用。
from functools import reduce from jax.interpreters import mlir from jax.interpreters.mlir import ir from jaxlib.hlo_helpers import custom_call # Register functions defined in gpu_ops as custom call target for GPUs for _name, _value in gpu_ops.get_rms_norm_registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="gpu") def element_type_to_descriptor_type_mapping(element_type): _element_type_to_descriptor_type_mapping = { ir.BF16Type.get(): gpu_ops.ElementType.BF16, ir.F16Type.get(): gpu_ops.ElementType.F16, ir.F32Type.get(): gpu_ops.ElementType.F32, ir.F64Type.get(): gpu_ops.ElementType.F64, } return _element_type_to_descriptor_type_mapping.get(element_type) def default_layouts(*shapes): return [range(len(shape) - 1, -1, -1) for shape in shapes] def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps): x_type = ir.RankedTensorType(x.type) x_shape = x_type.shape w_type = ir.RankedTensorType(weight.type) w_shape = w_type.shape iv_element_type = ( ir.F32Type.get() if x_type.element_type in [ir.F16Type.get(), ir.BF16Type.get()] else x_type.element_type ) n2 = reduce(lambda x, y: x * y, w_shape) n1 = reduce(lambda x, y: x * y, x_shape) // n2 opaque = gpu_ops.create_rms_norm_descriptor( n1, n2, eps, element_type_to_descriptor_type_mapping(x_type.element_type), element_type_to_descriptor_type_mapping(w_type.element_type), 0, # unused ) out = custom_call( b"rms_forward_affine_mixed_dtype", result_types=[ ir.RankedTensorType.get(x_shape, w_type.element_type), ir.RankedTensorType.get((n1,), iv_element_type), ], operands=[x, weight], backend_config=opaque, operand_layouts=default_layouts(x_shape, w_shape), result_layouts=default_layouts(x_shape, (n1,)), ).results return out mlir.register_lowering( _rms_norm_fwd_p, _rms_norm_fwd_cuda_lowering, platform="gpu", ) def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps): x_type = ir.RankedTensorType(x.type) x_shape = x_type.shape w_type = ir.RankedTensorType(weight.type) w_shape = w_type.shape iv_type = ir.RankedTensorType(invvar.type) n2 = reduce(lambda x, y: x * y, w_shape) n1 = reduce(lambda x, y: x * y, x_shape) // n2 part_grad_shape = ctx.avals_out[-1].shape opaque = gpu_ops.create_rms_norm_descriptor( n1, n2, eps, element_type_to_descriptor_type_mapping(x_type.element_type), element_type_to_descriptor_type_mapping(w_type.element_type), part_grad_shape[0], ) out = custom_call( b"rms_backward_affine", result_types=[ ir.RankedTensorType.get(x_shape, x_type.element_type), ir.RankedTensorType.get(w_shape, w_type.element_type), ir.RankedTensorType.get(part_grad_shape, iv_type.element_type), ], operands=[grad_output, invvar, x, weight], backend_config=opaque, operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape), result_layouts=default_layouts(x_shape, w_shape, part_grad_shape), ).results return out mlir.register_lowering( _rms_norm_bwd_p, _rms_norm_bwd_cuda_lowering, platform="gpu", )
JAX 中文文档(九)(5)https://developer.aliyun.com/article/1559677