JAX 是 TensorFlow 和 PyTorch 的新竞争对手。 JAX 强调简单性而不牺牲速度和可扩展性。由于 JAX 需要更少的样板代码,因此程序更短、更接近数学,因此更容易理解。
长话短说:
- 使用 import jax.numpy 访问 NumPy 函数,使用 import jax.scipy 访问 SciPy 函数。
- 通过使用 @jax.jit 进行装饰,可以加快即时编译速度。
- 使用 jax.grad 求导。
- 使用 jax.vmap 进行矢量化,并使用 jax.pmap 进行跨设备并行化。
函数式编程
JAX 遵循函数式编程哲学。这意味着您的函数必须是独立的或纯粹的:不允许有副作用。本质上,纯函数看起来像数学函数(图 1)。有输入进来,有东西出来,但与外界没有沟通。
例子#1
以下代码片段是一个非功能纯的示例。
import jax.numpy as jnp
bias = jnp.array(0)
def impure_example(x):
total = x + bias
return total
注意 impure_example 之外的偏差。在编译期间(见下文),偏差可能会被缓存,因此不再反映偏差的变化。
例子#2
这是一个pure的例子。
def pure_example(x, weights, bias):
activation = weights @ x + bias
return activation
在这里,pure_example 是独立的:所有参数都作为参数传递。
确定性采样器
在计算机中,不存在真正的随机性。相反,NumPy 和 TensorFlow 等库会跟踪伪随机数状态来生成“随机”样本。
函数式编程的直接后果是随机函数的工作方式不同。由于不再允许全局状态,因此每次采样随机数时都需要显式传入伪随机数生成器 (PRNG) 密钥
import jax
key = jax.random.PRNGKey(42)
u = jax.random.uniform(key)
此外,您有责任为任何后续调用推进“随机状态”。
key = jax.random.PRNGKey(43)
# Split off and consume subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)
# Split off and consume second subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)
..
jit
您可以通过即时编译 JAX 指令来加快代码速度。例如,要编译缩放指数线性单位 (SELU) 函数,请使用 jax.numpy 中的 NumPy 函数并将 jax.jit 装饰器添加到该函数,如下所示:
from jax import jit
@jit
def selu(x, α=1.67, λ=1.05):
return λ * jnp.where(x > 0, x, α * jnp.exp(x) - α)
JAX 会跟踪您的指令并将其转换为 jaxpr。这使得加速线性代数 (XLA) 编译器能够为您的加速器生成非常高效的优化代码。
gard
JAX 最强大的功能之一是您可以轻松获取 gard。使用 jax.grad,您可以定义一个新函数,即符号导数。
from jax import grad
def f(x):
return x + 0.5 * x**2
df_dx = grad(f)
d2f_dx2 = grad(grad(f))
正如您在示例中看到的,您不仅限于一阶导数。您可以通过简单地按顺序链接 grad 函数 n 次来获取 n 阶导数。
vmap 和 pmap
矩阵乘法使所有批次尺寸正确需要非常细心。 JAX 的矢量化映射函数 vmap 通过对函数进行矢量化来减轻这种负担。基本上,每个按元素应用函数 f 的代码块都是由 vmap 替换的候选者。让我们看一个例子。
计算线性函数:
def linear(x):
return weights @ x
在一批示例 [x₁, x2,..] 中,我们可以天真地(没有 vmap)实现它,如下所示:
def naively_batched_linear(X_batched):
return jnp.stack([linear(x) for x in X_batched])
相反,通过使用 vmap 对线性进行向量化,我们可以一次性计算整个批次:
def vmap_batched_linear(X_batched):
return vmap(linear)(X_batched)