JAX 教程
关键概念
本节简要介绍了 JAX 包的一些关键概念。
JAX 数组 (jax.Array
)
JAX 中的默认数组实现是 jax.Array
。在许多方面,它与您可能熟悉的 NumPy 包中的 numpy.ndarray
类型相似,但它也有一些重要的区别。
数组创建
我们通常不直接调用 jax.Array
构造函数,而是通过 JAX API 函数创建数组。例如,jax.numpy
提供了类似 NumPy 风格的数组构造功能,如 jax.numpy.zeros()
、jax.numpy.linspace()
、jax.numpy.arange()
等。
import jax import jax.numpy as jnp x = jnp.arange(5) isinstance(x, jax.Array)
True
如果您在代码中使用 Python 类型注解,jax.Array
是 jax 数组对象的适当注释(参见 jax.typing
以获取更多讨论)。
数组设备和分片
JAX 数组对象具有一个 devices
方法,允许您查看数组内容存储在哪里。在最简单的情况下,这将是单个 CPU 设备:
x.devices()
{CpuDevice(id=0)}
一般来说,数组可能会在多个设备上 分片,您可以通过 sharding
属性进行检查:
x.sharding
SingleDeviceSharding(device=CpuDevice(id=0))
在这里,数组位于单个设备上,但通常情况下,JAX 数组可以分布在多个设备或者多个主机上。要了解更多关于分片数组和并行计算的信息,请参阅分片计算介绍## 变换
除了用于操作数组的函数外,JAX 还包括许多用于操作 JAX 函数的变换。这些变换包括
jax.jit()
: 即时(JIT)编译;参见即时编译jax.vmap()
: 向量化变换;参见自动向量化jax.grad()
: 梯度变换;参见自动微分
以及其他几个。变换接受一个函数作为参数,并返回一个新的转换后的函数。例如,这是您可能如何对一个简单的 SELU 函数进行 JIT 编译:
def selu(x, alpha=1.67, lambda_=1.05): return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) selu_jit = jax.jit(selu) print(selu_jit(1.0))
1.05
通常情况下,您会看到使用 Python 的装饰器语法来应用变换以方便操作:
@jax.jit def selu(x, alpha=1.67, lambda_=1.05): return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
jit()
、vmap()
、grad()
等变换对于有效使用 JAX 至关重要,我们将在后续章节中详细介绍它们。## 跟踪
变换背后的魔法是跟踪器的概念。跟踪器是数组对象的抽象替身,传递给 JAX 函数,以提取函数编码的操作序列。
您可以通过打印转换后的 JAX 代码中的任何数组值来看到这一点;例如:
@jax.jit def f(x): print(x) return x + 1 x = jnp.arange(5) result = f(x)
Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=1/0)>
打印的值不是数组 x
,而是代表 x
的关键属性的 Tracer
实例,比如它的 shape
和 dtype
。通过使用追踪值执行函数,JAX 可以确定函数编码的操作序列,然后在实际执行这些操作之前执行转换:例如 jit()
、vmap()
和 grad()
可以将输入操作序列映射到变换后的操作序列。 ## Jaxprs
JAX 对操作序列有自己的中间表示形式,称为 jaxpr。jaxpr(JAX exPRession 的缩写)是一个函数程序的简单表示,包含一系列原始操作。
例如,考虑我们上面定义的 selu
函数:
def selu(x, alpha=1.67, lambda_=1.05): return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
我们可以使用 jax.make_jaxpr()
实用程序来将该函数转换为一个 jaxpr,给定特定的输入:
x = jnp.arange(5.0) jax.make_jaxpr(selu)(x)
{ lambda ; a:f32[5]. let b:bool[5] = gt a 0.0 c:f32[5] = exp a d:f32[5] = mul 1.6699999570846558 c e:f32[5] = sub d 1.6699999570846558 f:f32[5] = pjit[ name=_where jaxpr={ lambda ; g:bool[5] h:f32[5] i:f32[5]. let j:f32[5] = select_n g i h in (j,) } ] b a e k:f32[5] = mul 1.0499999523162842 f in (k,) }
与 Python 函数定义相比,可以看出它编码了函数表示的精确操作序列。我们稍后将深入探讨 JAX 内部的 jaxprs:jaxpr 语言。 ## Pytrees
JAX 函数和转换基本上操作数组,但实际上编写处理数组集合的代码更为方便:例如,神经网络可能会将其参数组织在具有有意义键的数组字典中。与其逐案处理这类结构,JAX 依赖于 pytree 抽象来统一处理这些集合。
以下是一些可以作为 pytrees 处理的对象的示例:
# (nested) list of parameters params = [1, 2, (jnp.arange(3), jnp.ones(2))] print(jax.tree.structure(params)) print(jax.tree.leaves(params))
PyTreeDef([*, *, (*, *)]) [1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]
# Dictionary of parameters params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)} print(jax.tree.structure(params)) print(jax.tree.leaves(params))
PyTreeDef({'W': *, 'b': *, 'n': *}) [Array([[1., 1.], [1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]
# Named tuple of parameters from typing import NamedTuple class Params(NamedTuple): a: int b: float params = Params(1, 5.0) print(jax.tree.structure(params)) print(jax.tree.leaves(params))
PyTreeDef(CustomNode(namedtuple[Params], [*, *])) [1, 5.0]
JAX 提供了许多用于处理 PyTrees 的通用实用程序;例如函数 jax.tree.map()
可以用于将函数映射到树中的每个叶子,而 jax.tree.reduce()
可以用于在树中的叶子上应用约简操作。
你可以在《使用 pytrees 教程》中了解更多信息。
即时编译
在这一部分,我们将进一步探讨 JAX 的工作原理,以及如何使其性能卓越。我们将讨论 jax.jit()
变换,它将 JAX Python 函数进行即时编译,以便在 XLA 中高效执行。
如何工作 JAX 变换
在前一节中,我们讨论了 JAX 允许我们转换 Python 函数的能力。JAX 通过将每个函数减少为一系列原始操作来实现这一点,每个原始操作代表一种基本的计算单位。
查看函数背后原始操作序列的一种方法是使用 jax.make_jaxpr()
:
import jax import jax.numpy as jnp global_list = [] def log2(x): global_list.append(x) ln_x = jnp.log(x) ln_2 = jnp.log(2.0) return ln_x / ln_2 print(jax.make_jaxpr(log2)(3.0))
{ lambda ; a:f32[]. let b:f32[] = log a c:f32[] = log 2.0 d:f32[] = div b c in (d,) }
文档的理解 Jaxprs 部分提供了有关上述输出含义的更多信息。
重要的是要注意,jaxpr 不捕获函数中存在的副作用:其中没有对 global_list.append(x)
的任何内容。这是一个特性,而不是一个错误:JAX 变换旨在理解无副作用(也称为函数纯粹)的代码。如果 纯函数 和 副作用 是陌生的术语,这在 🔪 JAX - The Sharp Bits 🔪: Pure Functions 中有稍微详细的解释。
非纯函数很危险,因为在 JAX 变换下它们可能无法按预期运行;它们可能会悄无声息地失败,或者产生意外的下游错误,如泄漏的跟踪器。此外,JAX 通常无法检测到是否存在副作用。(如果需要调试打印,请使用 jax.debug.print()
。要表达一般性副作用而牺牲性能,请参阅 jax.experimental.io_callback()
。要检查跟踪器泄漏而牺牲性能,请使用 jax.check_tracer_leaks()
)。
在跟踪时,JAX 通过 跟踪器 对象包装每个参数。这些跟踪器记录了在函数调用期间(即在常规 Python 中发生)对它们执行的所有 JAX 操作。然后,JAX 使用跟踪器记录重构整个函数。重构的输出是 jaxpr。由于跟踪器不记录 Python 的副作用,它们不会出现在 jaxpr 中。但是,副作用仍会在跟踪过程中发生。
注意:Python 的 print()
函数不是纯函数:文本输出是函数的副作用。因此,在跟踪期间,任何 print()
调用都将只发生一次,并且不会出现在 jaxpr 中:
def log2_with_print(x): print("printed x:", x) ln_x = jnp.log(x) ln_2 = jnp.log(2.0) return ln_x / ln_2 print(jax.make_jaxpr(log2_with_print)(3.))
printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> { lambda ; a:f32[]. let b:f32[] = log a c:f32[] = log 2.0 d:f32[] = div b c in (d,) }
看看打印出来的 x
是一个 Traced
对象?这就是 JAX 内部的工作原理。
Python 代码至少运行一次的事实严格来说是一个实现细节,因此不应依赖它。然而,在调试时理解它是有用的,因为您可以在计算的中间值打印出来。
一个关键的理解点是,jaxpr 捕捉函数在给定参数上执行的方式。例如,如果我们有一个 Python 条件语句,jaxpr 只会了解我们选择的分支:
def log2_if_rank_2(x): if x.ndim == 2: ln_x = jnp.log(x) ln_2 = jnp.log(2.0) return ln_x / ln_2 else: return x print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))
{ lambda ; a:i32[3]. let in (a,) }
JIT 编译函数
正如之前所解释的,JAX 使得操作能够使用相同的代码在 CPU/GPU/TPU 上执行。让我们看一个计算缩放指数线性单元(SELU)的例子,这是深度学习中常用的操作:
import jax import jax.numpy as jnp def selu(x, alpha=1.67, lambda_=1.05): return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) x = jnp.arange(1000000) %timeit selu(x).block_until_ready()
2.81 ms ± 27 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
上述代码一次只发送一个操作到加速器。这限制了 XLA 编译器优化我们函数的能力。
自然地,我们希望尽可能多地向 XLA 编译器提供代码,以便它能够完全优化它。为此,JAX 提供了jax.jit()
转换,它将即时编译一个与 JAX 兼容的函数。下面的示例展示了如何使用 JIT 加速前述函数。
selu_jit = jax.jit(selu) # Pre-compile the function before timing... selu_jit(x).block_until_ready() %timeit selu_jit(x).block_until_ready()
1.01 ms ± 2.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
刚刚发生了什么事:
- 我们定义了
selu_jit
作为selu
的编译版本。 - 我们在
x
上调用了selu_jit
一次。这是 JAX 进行其追踪的地方 - 它需要一些输入来包装成追踪器。然后,jaxpr 使用 XLA 编译成非常高效的代码,针对您的 GPU 或 TPU 进行优化。最后,编译的代码被执行以满足调用。后续对selu_jit
的调用将直接使用编译后的代码,跳过 Python 实现。(如果我们没有单独包括预热调用,一切仍将正常运行,但编译时间将包含在基准测试中。因为我们在基准测试中运行多个循环,所以仍会更快,但这不是公平的比较。) - 我们计时了编译版本的执行速度。(注意使用
block_until_ready()
,这是由于 JAX 的异步调度所需。)
为什么我们不能把所有东西都即时编译(JIT)呢?
在上面的例子中,你可能会想知道我们是否应该简单地对每个函数应用jax.jit()
。要理解为什么不是这样,并且何时需要/不需要应用jit
,让我们首先检查一些jit
不适用的情况。
# Condition on value of x. def f(x): if x > 0: return x else: return 2 * x jax.jit(f)(10) # Raises an error
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].. The error occurred while tracing the function f at /tmp/ipykernel_1169/2956679937.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument x. See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
# While loop conditioned on x and n. def g(x, n): i = 0 while i < n: i += 1 return x + i jax.jit(g)(10, 20) # Raises an error
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].. The error occurred while tracing the function g at /tmp/ipykernel_1169/722961019.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument n. See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
在这两种情况下的问题是,我们尝试使用运行时值来条件追踪时间流程。在 JIT 中追踪的值,例如这里的x
和n
,只能通过它们的静态属性(如shape
或dtype
)影响控制流,而不能通过它们的值。有关 Python 控制流与 JAX 交互的更多详细信息,请参见🔪 JAX - The Sharp Bits 🔪: Control Flow。
处理这个问题的一种方法是重写代码,避免在值条件上使用条件语句。另一种方法是使用特殊的控制流操作符,例如jax.lax.cond()
。然而,有时这并不可行或实际。在这种情况下,可以考虑只对函数的部分进行 JIT 编译。例如,如果函数中最消耗计算资源的部分在循环内部,我们可以只对内部的那部分进行 JIT 编译(但务必查看关于缓存的下一节,以避免出现问题):
# While loop conditioned on x and n with a jitted body. @jax.jit def loop_body(prev_i): return prev_i + 1 def g_inner_jitted(x, n): i = 0 while i < n: i = loop_body(i) return x + i g_inner_jitted(10, 20)
Array(30, dtype=int32, weak_type=True)
将参数标记为静态的
如果我们确实需要对具有输入值条件的函数进行 JIT 编译,我们可以告诉 JAX 通过指定static_argnums
或static_argnames
来帮助自己获取特定输入的较少抽象的追踪器。这样做的成本是生成的 jaxpr 和编译的工件依赖于传递的特定值,因此 JAX 将不得不针对指定静态输入的每个新值重新编译函数。只有在函数保证看到有限的静态值集时,这才是一个好策略。
f_jit_correct = jax.jit(f, static_argnums=0) print(f_jit_correct(10))
10
g_jit_correct = jax.jit(g, static_argnames=['n']) print(g_jit_correct(10, 20))
30
当使用jit
作为装饰器时,要指定这些参数的一种常见模式是使用 Python 的functools.partial()
:
from functools import partial @partial(jax.jit, static_argnames=['n']) def g_jit_decorated(x, n): i = 0 while i < n: i += 1 return x + i print(g_jit_decorated(10, 20))
30
JIT 和缓存
通过第一次 JIT 调用的编译开销,了解jax.jit()
如何以及何时缓存先前的编译是有效使用它的关键。
假设我们定义f = jax.jit(g)
。当我们首次调用f
时,它会被编译,并且生成的 XLA 代码将被缓存。后续调用f
将重用缓存的代码。这就是jax.jit
如何弥补编译的前期成本。
如果我们指定了static_argnums
,那么缓存的代码将仅在标记为静态的参数值相同时使用。如果它们中任何一个发生更改,将重新编译。如果存在许多值,则您的程序可能会花费更多时间进行编译,而不是逐个执行操作。
避免在循环或其他 Python 作用域内定义的临时函数上调用jax.jit()
。对于大多数情况,JAX 能够在后续调用jax.jit()
时使用编译和缓存的函数。然而,由于缓存依赖于函数的哈希值,在重新定义等价函数时会引发问题。这将导致每次在循环中不必要地重新编译:
from functools import partial def unjitted_loop_body(prev_i): return prev_i + 1 def g_inner_jitted_partial(x, n): i = 0 while i < n: # Don't do this! each time the partial returns # a function with different hash i = jax.jit(partial(unjitted_loop_body))(i) return x + i def g_inner_jitted_lambda(x, n): i = 0 while i < n: # Don't do this!, lambda will also return # a function with a different hash i = jax.jit(lambda x: unjitted_loop_body(x))(i) return x + i def g_inner_jitted_normal(x, n): i = 0 while i < n: # this is OK, since JAX can find the # cached, compiled function i = jax.jit(unjitted_loop_body)(i) return x + i print("jit called in a loop with partials:") %timeit g_inner_jitted_partial(10, 20).block_until_ready() print("jit called in a loop with lambdas:") %timeit g_inner_jitted_lambda(10, 20).block_until_ready() print("jit called in a loop with caching:") %timeit g_inner_jitted_normal(10, 20).block_until_ready()
jit called in a loop with partials: 217 ms ± 2.03 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) jit called in a loop with lambdas: 219 ms ± 5.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) jit called in a loop with caching: 2.33 ms ± 29.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
JAX 中文文档(二)(2)https://developer.aliyun.com/article/1559669