JAX 中文文档(二)(1)

简介: JAX 中文文档(二)


原文:jax.readthedocs.io/en/latest/

JAX 教程

原文:jax.readthedocs.io/en/latest/tutorials.html

  • 快速入门
  • 关键概念
  • 即时编译
  • 自动向量化
  • 自动微分
  • 调试入门
  • 伪随机数
  • 使用 pytrees 工作
  • 分片计算入门
  • 有状态计算

关键概念

原文:jax.readthedocs.io/en/latest/key-concepts.html

本节简要介绍了 JAX 包的一些关键概念。

JAX 数组 (jax.Array)

JAX 中的默认数组实现是 jax.Array。在许多方面,它与您可能熟悉的 NumPy 包中的 numpy.ndarray 类型相似,但它也有一些重要的区别。

数组创建

我们通常不直接调用 jax.Array 构造函数,而是通过 JAX API 函数创建数组。例如,jax.numpy 提供了类似 NumPy 风格的数组构造功能,如 jax.numpy.zeros()jax.numpy.linspace()jax.numpy.arange() 等。

import jax
import jax.numpy as jnp
x = jnp.arange(5)
isinstance(x, jax.Array) 
True 

如果您在代码中使用 Python 类型注解,jax.Array 是 jax 数组对象的适当注释(参见 jax.typing 以获取更多讨论)。

数组设备和分片

JAX 数组对象具有一个 devices 方法,允许您查看数组内容存储在哪里。在最简单的情况下,这将是单个 CPU 设备:

x.devices() 
{CpuDevice(id=0)} 

一般来说,数组可能会在多个设备上 分片,您可以通过 sharding 属性进行检查:

x.sharding 
SingleDeviceSharding(device=CpuDevice(id=0)) 

在这里,数组位于单个设备上,但通常情况下,JAX 数组可以分布在多个设备或者多个主机上。要了解更多关于分片数组和并行计算的信息,请参阅分片计算介绍## 变换

除了用于操作数组的函数外,JAX 还包括许多用于操作 JAX 函数的变换。这些变换包括

  • jax.jit(): 即时(JIT)编译;参见即时编译
  • jax.vmap(): 向量化变换;参见自动向量化
  • jax.grad(): 梯度变换;参见自动微分

以及其他几个。变换接受一个函数作为参数,并返回一个新的转换后的函数。例如,这是您可能如何对一个简单的 SELU 函数进行 JIT 编译:

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
selu_jit = jax.jit(selu)
print(selu_jit(1.0)) 
1.05 

通常情况下,您会看到使用 Python 的装饰器语法来应用变换以方便操作:

@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) 

jit()vmap()grad() 等变换对于有效使用 JAX 至关重要,我们将在后续章节中详细介绍它们。## 跟踪

变换背后的魔法是跟踪器的概念。跟踪器是数组对象的抽象替身,传递给 JAX 函数,以提取函数编码的操作序列。

您可以通过打印转换后的 JAX 代码中的任何数组值来看到这一点;例如:

@jax.jit
def f(x):
  print(x)
  return x + 1
x = jnp.arange(5)
result = f(x) 
Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=1/0)> 

打印的值不是数组 x,而是代表 x 的关键属性的 Tracer 实例,比如它的 shapedtype。通过使用追踪值执行函数,JAX 可以确定函数编码的操作序列,然后在实际执行这些操作之前执行转换:例如 jit()vmap()grad() 可以将输入操作序列映射到变换后的操作序列。 ## Jaxprs

JAX 对操作序列有自己的中间表示形式,称为 jaxpr。jaxpr(JAX exPRession 的缩写)是一个函数程序的简单表示,包含一系列原始操作。

例如,考虑我们上面定义的 selu 函数:

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) 

我们可以使用 jax.make_jaxpr() 实用程序来将该函数转换为一个 jaxpr,给定特定的输入:

x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x) 
{ lambda ; a:f32[5]. let
    b:bool[5] = gt a 0.0
    c:f32[5] = exp a
    d:f32[5] = mul 1.6699999570846558 c
    e:f32[5] = sub d 1.6699999570846558
    f:f32[5] = pjit[
      name=_where
      jaxpr={ lambda ; g:bool[5] h:f32[5] i:f32[5]. let
          j:f32[5] = select_n g i h
        in (j,) }
    ] b a e
    k:f32[5] = mul 1.0499999523162842 f
  in (k,) } 

与 Python 函数定义相比,可以看出它编码了函数表示的精确操作序列。我们稍后将深入探讨 JAX 内部的 jaxprs:jaxpr 语言。 ## Pytrees

JAX 函数和转换基本上操作数组,但实际上编写处理数组集合的代码更为方便:例如,神经网络可能会将其参数组织在具有有意义键的数组字典中。与其逐案处理这类结构,JAX 依赖于 pytree 抽象来统一处理这些集合。

以下是一些可以作为 pytrees 处理的对象的示例:

# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]
print(jax.tree.structure(params))
print(jax.tree.leaves(params)) 
PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)] 
# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
print(jax.tree.structure(params))
print(jax.tree.leaves(params)) 
PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
       [1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5] 
# Named tuple of parameters
from typing import NamedTuple
class Params(NamedTuple):
  a: int
  b: float
params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params)) 
PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0] 

JAX 提供了许多用于处理 PyTrees 的通用实用程序;例如函数 jax.tree.map() 可以用于将函数映射到树中的每个叶子,而 jax.tree.reduce() 可以用于在树中的叶子上应用约简操作。

你可以在《使用 pytrees 教程》中了解更多信息。

即时编译

原文:jax.readthedocs.io/en/latest/jit-compilation.html

在这一部分,我们将进一步探讨 JAX 的工作原理,以及如何使其性能卓越。我们将讨论 jax.jit() 变换,它将 JAX Python 函数进行即时编译,以便在 XLA 中高效执行。

如何工作 JAX 变换

在前一节中,我们讨论了 JAX 允许我们转换 Python 函数的能力。JAX 通过将每个函数减少为一系列原始操作来实现这一点,每个原始操作代表一种基本的计算单位。

查看函数背后原始操作序列的一种方法是使用 jax.make_jaxpr()

import jax
import jax.numpy as jnp
global_list = []
def log2(x):
  global_list.append(x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2
print(jax.make_jaxpr(log2)(3.0)) 
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) } 

文档的理解 Jaxprs 部分提供了有关上述输出含义的更多信息。

重要的是要注意,jaxpr 不捕获函数中存在的副作用:其中没有对 global_list.append(x) 的任何内容。这是一个特性,而不是一个错误:JAX 变换旨在理解无副作用(也称为函数纯粹)的代码。如果 纯函数副作用 是陌生的术语,这在 🔪 JAX - The Sharp Bits 🔪: Pure Functions 中有稍微详细的解释。

非纯函数很危险,因为在 JAX 变换下它们可能无法按预期运行;它们可能会悄无声息地失败,或者产生意外的下游错误,如泄漏的跟踪器。此外,JAX 通常无法检测到是否存在副作用。(如果需要调试打印,请使用 jax.debug.print()。要表达一般性副作用而牺牲性能,请参阅 jax.experimental.io_callback()。要检查跟踪器泄漏而牺牲性能,请使用 jax.check_tracer_leaks())。

在跟踪时,JAX 通过 跟踪器 对象包装每个参数。这些跟踪器记录了在函数调用期间(即在常规 Python  中发生)对它们执行的所有 JAX 操作。然后,JAX 使用跟踪器记录重构整个函数。重构的输出是 jaxpr。由于跟踪器不记录 Python  的副作用,它们不会出现在 jaxpr 中。但是,副作用仍会在跟踪过程中发生。

注意:Python 的 print() 函数不是纯函数:文本输出是函数的副作用。因此,在跟踪期间,任何 print() 调用都将只发生一次,并且不会出现在 jaxpr 中:

def log2_with_print(x):
  print("printed x:", x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2
print(jax.make_jaxpr(log2_with_print)(3.)) 
printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) } 

看看打印出来的 x 是一个 Traced 对象?这就是 JAX 内部的工作原理。

Python 代码至少运行一次的事实严格来说是一个实现细节,因此不应依赖它。然而,在调试时理解它是有用的,因为您可以在计算的中间值打印出来。

一个关键的理解点是,jaxpr 捕捉函数在给定参数上执行的方式。例如,如果我们有一个 Python 条件语句,jaxpr 只会了解我们选择的分支:

def log2_if_rank_2(x):
  if x.ndim == 2:
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2
  else:
    return x
print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3]))) 
{ lambda ; a:i32[3]. let  in (a,) } 

JIT 编译函数

正如之前所解释的,JAX 使得操作能够使用相同的代码在 CPU/GPU/TPU 上执行。让我们看一个计算缩放指数线性单元SELU)的例子,这是深度学习中常用的操作:

import jax
import jax.numpy as jnp
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = jnp.arange(1000000)
%timeit selu(x).block_until_ready() 
2.81 ms ± 27 μs per loop (mean ± std. dev. of 7 runs, 100 loops each) 

上述代码一次只发送一个操作到加速器。这限制了 XLA 编译器优化我们函数的能力。

自然地,我们希望尽可能多地向 XLA 编译器提供代码,以便它能够完全优化它。为此,JAX 提供了jax.jit()转换,它将即时编译一个与 JAX 兼容的函数。下面的示例展示了如何使用 JIT 加速前述函数。

selu_jit = jax.jit(selu)
# Pre-compile the function before timing...
selu_jit(x).block_until_ready()
%timeit selu_jit(x).block_until_ready() 
1.01 ms ± 2.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) 

刚刚发生了什么事:

  1. 我们定义了selu_jit作为selu的编译版本。
  2. 我们在x上调用了selu_jit一次。这是 JAX 进行其追踪的地方 - 它需要一些输入来包装成追踪器。然后,jaxpr 使用 XLA 编译成非常高效的代码,针对您的 GPU 或 TPU 进行优化。最后,编译的代码被执行以满足调用。后续对selu_jit的调用将直接使用编译后的代码,跳过 Python 实现。(如果我们没有单独包括预热调用,一切仍将正常运行,但编译时间将包含在基准测试中。因为我们在基准测试中运行多个循环,所以仍会更快,但这不是公平的比较。)
  3. 我们计时了编译版本的执行速度。(注意使用block_until_ready(),这是由于 JAX 的异步调度所需。)

为什么我们不能把所有东西都即时编译(JIT)呢?

在上面的例子中,你可能会想知道我们是否应该简单地对每个函数应用jax.jit()。要理解为什么不是这样,并且何时需要/不需要应用jit,让我们首先检查一些jit不适用的情况。

# Condition on value of x.
def f(x):
  if x > 0:
    return x
  else:
    return 2 * x
jax.jit(f)(10)  # Raises an error 
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /tmp/ipykernel_1169/2956679937.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError 
# While loop conditioned on x and n.
def g(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i
jax.jit(g)(10, 20)  # Raises an error 
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function g at /tmp/ipykernel_1169/722961019.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError 

在这两种情况下的问题是,我们尝试使用运行时值来条件追踪时间流程。在 JIT 中追踪的值,例如这里的xn,只能通过它们的静态属性(如shapedtype)影响控制流,而不能通过它们的值。有关 Python 控制流与 JAX 交互的更多详细信息,请参见🔪 JAX - The Sharp Bits 🔪: Control Flow

处理这个问题的一种方法是重写代码,避免在值条件上使用条件语句。另一种方法是使用特殊的控制流操作符,例如jax.lax.cond()。然而,有时这并不可行或实际。在这种情况下,可以考虑只对函数的部分进行 JIT 编译。例如,如果函数中最消耗计算资源的部分在循环内部,我们可以只对内部的那部分进行 JIT 编译(但务必查看关于缓存的下一节,以避免出现问题):

# While loop conditioned on x and n with a jitted body.
@jax.jit
def loop_body(prev_i):
  return prev_i + 1
def g_inner_jitted(x, n):
  i = 0
  while i < n:
    i = loop_body(i)
  return x + i
g_inner_jitted(10, 20) 
Array(30, dtype=int32, weak_type=True) 

将参数标记为静态的

如果我们确实需要对具有输入值条件的函数进行 JIT 编译,我们可以告诉 JAX 通过指定static_argnumsstatic_argnames来帮助自己获取特定输入的较少抽象的追踪器。这样做的成本是生成的 jaxpr 和编译的工件依赖于传递的特定值,因此 JAX 将不得不针对指定静态输入的每个新值重新编译函数。只有在函数保证看到有限的静态值集时,这才是一个好策略。

f_jit_correct = jax.jit(f, static_argnums=0)
print(f_jit_correct(10)) 
10 
g_jit_correct = jax.jit(g, static_argnames=['n'])
print(g_jit_correct(10, 20)) 
30 

当使用jit作为装饰器时,要指定这些参数的一种常见模式是使用 Python 的functools.partial()

from functools import partial
@partial(jax.jit, static_argnames=['n'])
def g_jit_decorated(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i
print(g_jit_decorated(10, 20)) 
30 

JIT 和缓存

通过第一次 JIT 调用的编译开销,了解jax.jit()如何以及何时缓存先前的编译是有效使用它的关键。

假设我们定义f = jax.jit(g)。当我们首次调用f时,它会被编译,并且生成的 XLA 代码将被缓存。后续调用f将重用缓存的代码。这就是jax.jit如何弥补编译的前期成本。

如果我们指定了static_argnums,那么缓存的代码将仅在标记为静态的参数值相同时使用。如果它们中任何一个发生更改,将重新编译。如果存在许多值,则您的程序可能会花费更多时间进行编译,而不是逐个执行操作。

避免在循环或其他 Python 作用域内定义的临时函数上调用jax.jit()。对于大多数情况,JAX 能够在后续调用jax.jit()时使用编译和缓存的函数。然而,由于缓存依赖于函数的哈希值,在重新定义等价函数时会引发问题。这将导致每次在循环中不必要地重新编译:

from functools import partial
def unjitted_loop_body(prev_i):
  return prev_i + 1
def g_inner_jitted_partial(x, n):
  i = 0
  while i < n:
    # Don't do this! each time the partial returns
    # a function with different hash
    i = jax.jit(partial(unjitted_loop_body))(i)
  return x + i
def g_inner_jitted_lambda(x, n):
  i = 0
  while i < n:
    # Don't do this!, lambda will also return
    # a function with a different hash
    i = jax.jit(lambda x: unjitted_loop_body(x))(i)
  return x + i
def g_inner_jitted_normal(x, n):
  i = 0
  while i < n:
    # this is OK, since JAX can find the
    # cached, compiled function
    i = jax.jit(unjitted_loop_body)(i)
  return x + i
print("jit called in a loop with partials:")
%timeit g_inner_jitted_partial(10, 20).block_until_ready()
print("jit called in a loop with lambdas:")
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()
print("jit called in a loop with caching:")
%timeit g_inner_jitted_normal(10, 20).block_until_ready() 
jit called in a loop with partials:
217 ms ± 2.03 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with lambdas:
219 ms ± 5.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with caching:
2.33 ms ± 29.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each) 


JAX 中文文档(二)(2)https://developer.aliyun.com/article/1559669

相关文章
|
3月前
|
并行计算 API C++
JAX 中文文档(九)(4)
JAX 中文文档(九)
33 1
|
3月前
|
机器学习/深度学习 存储 移动开发
JAX 中文文档(八)(1)
JAX 中文文档(八)
22 1
|
3月前
|
并行计算 测试技术 异构计算
JAX 中文文档(一)(5)
JAX 中文文档(一)
59 0
|
3月前
JAX 中文文档(九)(3)
JAX 中文文档(九)
30 0
|
3月前
|
C++ 索引 Python
JAX 中文文档(九)(2)
JAX 中文文档(九)
19 0
|
3月前
|
存储 机器学习/深度学习 TensorFlow
JAX 中文文档(七)(5)
JAX 中文文档(七)
23 0
|
3月前
|
机器学习/深度学习 程序员 编译器
JAX 中文文档(三)(1)
JAX 中文文档(三)
28 0
|
3月前
|
API Python
JAX 中文文档(八)(3)
JAX 中文文档(八)
26 0
|
3月前
|
存储 缓存 索引
JAX 中文文档(五)(3)
JAX 中文文档(五)
45 0
|
3月前
|
机器学习/深度学习 测试技术 索引
JAX 中文文档(二)(4)
JAX 中文文档(二)
38 0