JAX 中文文档(二)(3)

简介: JAX 中文文档(二)

JAX 中文文档(二)(2)https://developer.aliyun.com/article/1559669


调试介绍

原文:jax.readthedocs.io/en/latest/debugging.html

本节介绍了一组内置的 JAX 调试方法 — jax.debug.print()jax.debug.breakpoint()jax.debug.callback() — 您可以将其与各种 JAX 转换一起使用。

让我们从 jax.debug.print() 开始。

JAX 的 debug.print 用于高级别

TL;DR 这是一个经验法则:

  • 对于使用 jax.jit()jax.vmap() 和其他动态数组值的跟踪,使用 jax.debug.print()
  • 对于静态值(例如 dtypes 和数组形状),使用 Python print()

回顾即时编译时,使用 jax.jit() 转换函数时,Python 代码在数组的抽象跟踪器的位置执行。因此,Python print() 函数只会打印此跟踪器值:

import jax
import jax.numpy as jnp
@jax.jit
def f(x):
  print("print(x) ->", x)
  y = jnp.sin(x)
  print("print(y) ->", y)
  return y
result = f(2.) 
print(x) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
print(y) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> 

Python 的 print 在跟踪时间执行,即在运行时值存在之前。如果要打印实际的运行时值,可以使用 jax.debug.print()

@jax.jit
def f(x):
  jax.debug.print("jax.debug.print(x) -> {x}", x=x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {y}", y=y)
  return y
result = f(2.) 
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314 

类似地,在 jax.vmap() 内部,使用 Python 的 print 只会打印跟踪器;要打印正在映射的值,请使用 jax.debug.print()

def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {}", y)
  return y
xs = jnp.arange(3.)
result = jax.vmap(f)(xs) 
jax.debug.print(x) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(y) -> 0.9092974066734314 

这里是使用 jax.lax.map() 的结果,它是一个顺序映射而不是向量化:

result = jax.lax.map(f, xs) 
jax.debug.print(x) -> 0.0
jax.debug.print(y) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314 

注意顺序不同,如 jax.vmap()jax.lax.map() 以不同方式计算相同结果。在调试时,评估顺序的细节正是您可能需要检查的。

下面是一个关于 jax.grad() 的示例,其中 jax.debug.print() 仅打印前向传递。在这种情况下,行为类似于 Python 的 print(),但如果在调用期间应用 jax.jit(),它是一致的。

def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  return x ** 2
result = jax.grad(f)(1.) 
jax.debug.print(x) -> 1.0 

有时,当参数彼此不依赖时,调用 jax.debug.print() 可能会以不同的顺序打印它们,当使用 JAX 转换进行分阶段时。如果需要原始顺序,例如首先是 x: ... 然后是 y: ...,请添加 ordered=True 参数。

例如:

@jax.jit
def f(x, y):
  jax.debug.print("jax.debug.print(x) -> {}", x, ordered=True)
  jax.debug.print("jax.debug.print(y) -> {}", y, ordered=True)
  return x + y
f(1, 2) 
jax.debug.print(x) -> 1
jax.debug.print(y) -> 2 
Array(3, dtype=int32, weak_type=True) 

要了解更多关于 jax.debug.print() 及其详细信息,请参阅高级调试。

JAX 的 debug.breakpoint 用于类似 pdb 的调试

TL;DR 使用 jax.debug.breakpoint() 暂停您的 JAX 程序执行以检查值。

要在调试期间暂停编译的 JAX 程序的某些点,您可以使用 jax.debug.breakpoint()。提示类似于 Python 的 pdb,允许您检查调用堆栈中的值。实际上,jax.debug.breakpoint()jax.debug.callback() 的应用,用于捕获有关调用堆栈的信息。

要在 breakpoint 调试会话期间打印所有可用命令,请使用 help 命令。(完整的调试器命令、其强大之处及限制在高级调试中有详细介绍。)

这是调试器会话可能看起来的示例:

@jax.jit
def f(x):
  y, z = jnp.sin(x), jnp.cos(x)
  jax.debug.breakpoint()
  return y * z
f(2.) # ==> Pauses during execution 

[外链图片转存中…(img-bP2YtS7x-1718950373556)]

对于依赖值的断点,您可以使用像jax.lax.cond()这样的运行时条件:

def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.breakpoint()
  jax.lax.cond(is_finite, true_fn, false_fn, x)
@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z
f(2., 1.) # ==> No breakpoint 
Array(2., dtype=float32, weak_type=True) 
f(2., 0.) # ==> Pauses during execution 

JAX 调试回调以增强调试期间的控制

jax.debug.print()jax.debug.breakpoint()都使用更灵活的jax.debug.callback()实现,它通过 Python 回调执行主机端逻辑,提供更大的控制。它与jax.jit()jax.vmap()jax.grad()和其他转换兼容(有关更多信息,请参阅外部回调的回调类型表)。

例如:

import logging
def log_value(x):
  logging.warning(f'Logged value: {x}')
@jax.jit
def f(x):
  jax.debug.callback(log_value, x)
  return x
f(1.0); 
WARNING:root:Logged value: 1.0 

此回调与其他转换兼容,包括jax.vmap()jax.grad()

x = jnp.arange(5.0)
jax.vmap(f)(x); 
WARNING:root:Logged value: 0.0
WARNING:root:Logged value: 1.0
WARNING:root:Logged value: 2.0
WARNING:root:Logged value: 3.0
WARNING:root:Logged value: 4.0 
jax.grad(f)(1.0); 
WARNING:root:Logged value: 1.0 

这使得jax.debug.callback()在通用调试中非常有用。

您可以在外部回调中了解更多关于jax.debug.callback()和其他类型 JAX 回调的信息。

下一步

查看高级调试以了解更多关于在 JAX 中调试的信息。

伪随机数

原文:jax.readthedocs.io/en/latest/random-numbers.html

本节将重点讨论 jax.random 和伪随机数生成(PRNG);即,通过算法生成数列,其特性近似于从适当分布中抽样的随机数列的过程。

PRNG 生成的序列并非真正随机,因为它们实际上由其初始值决定,通常称为 seed,并且每一步的随机抽样都是由从一个样本到下一个样本传递的 state 的确定性函数决定。

伪随机数生成是任何机器学习或科学计算框架的重要组成部分。一般而言,JAX 力求与 NumPy 兼容,但伪随机数生成是一个显著的例外。

为了更好地理解 JAX 和 NumPy 在随机数生成方法上的差异,我们将在本节中讨论两种方法。

NumPy 中的随机数

NumPy 中的伪随机数生成由 numpy.random 模块本地支持。在 NumPy 中,伪随机数生成基于全局 state,可以使用 numpy.random.seed() 将其设置为确定性初始条件。

import numpy as np
np.random.seed(0) 

您可以使用以下命令检查状态的内容。

def print_truncated_random_state():
  """To avoid spamming the outputs, print only part of the state."""
  full_random_state = np.random.get_state()
  print(str(full_random_state)[:460], '...')
print_truncated_random_state() 
('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ... 

每次对随机函数调用都会更新 state

np.random.seed(0)
print_truncated_random_state() 
('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ... 
_ = np.random.uniform()
print_truncated_random_state() 
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
       3904844661,  676747479, 2085143622, 1056793272, 3812477442,
       2168787041,  275552121, 2696932952, 3432054210, 1657102335,
       3518946594,  962584079, 1051271004, 3806145045, 1414436097,
       2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
        696824676, 2399811678, 3992505346,  569184356, 2626558620,
        136797809, 4273176064,  296167901, 343 ... 

NumPy 允许您在单个函数调用中同时抽取单个数字或整个向量的数字。例如,您可以通过以下方式从均匀分布中抽取一个包含 3 个标量的向量:

np.random.seed(0)
print(np.random.uniform(size=3)) 
[0.5488135  0.71518937 0.60276338] 

NumPy 提供了顺序等效保证,这意味着连续抽取 N 个数字或一次抽样 N 个数字的向量将得到相同的伪随机序列:

np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))
np.random.seed(0)
print("all at once: ", np.random.uniform(size=3)) 
individually: [0.5488135  0.71518937 0.60276338]
all at once:  [0.5488135  0.71518937 0.60276338] 

JAX 中的随机数

JAX 的随机数生成与 NumPy 的方式有重要的区别,因为 NumPy 的 PRNG 设计使得同时保证多种理想特性变得困难。具体而言,在 JAX 中,我们希望 PRNG 生成是:

  1. 可复现的,
  2. 可并行化,
  3. 可向量化。

我们将在接下来讨论原因。首先,我们将集中讨论基于全局状态的伪随机数生成设计的影响。考虑以下代码:

import numpy as np
np.random.seed(0)
def bar(): return np.random.uniform()
def baz(): return np.random.uniform()
def foo(): return bar() + 2 * baz()
print(foo()) 
1.9791922366721637 

函数 foo 对从均匀分布中抽样的两个标量求和。

如果我们假设 bar()baz()  的执行顺序是可预测的,那么此代码的输出只能满足要求 #1。在 NumPy 中,这不是问题,因为它总是按照 Python  解释器定义的顺序执行代码。然而,在 JAX 中,情况就比较复杂了:为了执行效率,我们希望 JIT  编译器可以自由地重新排序、省略和融合我们定义的函数中的各种操作。此外,在多设备环境中执行时,每个进程需要同步全局状态,这会影响执行效率。

明确的随机状态

为了避免这个问题,JAX 避免使用隐式的全局随机状态,而是通过随机 key 显式地跟踪状态:

from jax import random
key = random.key(42)
print(key) 
Array((), dtype=key<fry>) overlaying:
[ 0 42] 

注意

本节使用由 jax.random.key() 生成的新型类型化 PRNG key,而不是由 jax.random.PRNGKey() 生成的旧型原始 PRNG key。有关详情,请参阅 JEP 9263:类型化 key 和可插拔 RNG。

一个 key 是一个具有特定 PRNG 实现对应的特殊数据类型的数组;在默认实现中,每个 key 由一对 uint32 值支持。

key 实际上是 NumPy 隐藏状态对象的替代品,但我们显式地将其传递给 jax.random() 函数。重要的是,随机函数消耗 key,但不修改它:将相同的 key 对象传递给随机函数将始终生成相同的样本。

print(random.normal(key))
print(random.normal(key)) 
-0.18471177
-0.18471177 

即使使用不同的 random API,重复使用相同的 key 也可能导致相关的输出,这通常是不可取的。

经验法则是:永远不要重复使用 key(除非你希望得到相同的输出)。

为了生成不同且独立的样本,你必须在将 key 传递给随机函数之前显式地调用 split()

for i in range(3):
  new_key, subkey = random.split(key)
  del key  # The old key is consumed by split() -- we must never use it again.
  val = random.normal(subkey)
  del subkey  # The subkey is consumed by normal().
  print(f"draw {i}: {val}")
  key = new_key  # new_key is safe to use in the next iteration. 
draw 0: 1.369469404220581
draw 1: -0.19947023689746857
draw 2: -2.298278331756592 

(这里调用 del 并非必须,但我们这样做是为了强调一旦使用过的 key 不应再次使用。)

jax.random.split() 是一个确定性函数,它将一个 key 转换为若干独立(在伪随机性意义上)的新 key。我们保留其中一个作为 new_key,可以安全地将额外生成的唯一 subkey 作为随机函数的输入,然后永久丢弃它。如果你需要从正态分布中获取另一个样本,你需要再次执行 split(key),以此类推:关键的一点是,你永远不要重复使用同一个 key

调用 split(key) 的输出的哪一部分被称为 key,哪一部分被称为 subkey 并不重要。它们都是具有相同状态的独立 keykey/subkey 命名约定是一种典型的使用模式,有助于跟踪 key 如何被消耗:subkey 被用于随机函数的直接消耗,而 key 则保留用于稍后生成更多的随机性。

通常,上述示例可以简洁地写成

key, subkey = random.split(key) 

这会自动丢弃旧 key。值得注意的是,split() 不仅可以创建两个 key,还可以创建多个:

key, *forty_two_subkeys = random.split(key, num=43) 

缺乏顺序等价性

NumPy 和 JAX 随机模块之间的另一个区别涉及到上述的顺序等价性保证。

与 NumPy 类似,JAX 的随机模块也允许对向量进行抽样。但是,JAX 不提供顺序等价性保证,因为这样做会干扰 SIMD 硬件上的向量化(上述要求 #3)。

在下面的示例中,使用三个子密钥分别从正态分布中抽取 3 个值,与使用单个密钥并指定shape=(3,)会得到不同的结果:

key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)
key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,))) 
individually: [-0.04838832  0.10796154 -1.2226542 ]
all at once:  [ 0.18693547 -1.2806505  -1.5593132 ] 

缺乏顺序等价性使我们能够更高效地编写代码;例如,不用通过顺序循环生成上述的sequence,而是可以使用jax.vmap()以向量化方式计算相同的结果:

import jax
print("vectorized:", jax.vmap(random.normal)(subkeys)) 
vectorized: [-0.04838832  0.10796154 -1.2226542 ] 

下一步

欲了解更多关于 JAX 随机数的信息,请参阅jax.random模块的文档。如果您对 JAX 随机数生成器的设计细节感兴趣,请参阅 JAX PRNG 设计。


JAX 中文文档(二)(4)https://developer.aliyun.com/article/1559671

相关文章
|
9天前
|
并行计算 API 异构计算
JAX 中文文档(六)(2)
JAX 中文文档(六)
13 1
|
9天前
|
并行计算 Linux 异构计算
JAX 中文文档(一)(1)
JAX 中文文档(一)
16 0
|
9天前
|
Serverless C++ Python
JAX 中文文档(九)(5)
JAX 中文文档(九)
15 0
|
9天前
|
API 索引 Python
JAX 中文文档(三)(4)
JAX 中文文档(三)
6 0
|
9天前
|
编译器 API C++
JAX 中文文档(三)(3)
JAX 中文文档(三)
9 0
|
9天前
|
机器学习/深度学习 并行计算 安全
JAX 中文文档(七)(1)
JAX 中文文档(七)
11 0
|
9天前
|
存储 缓存 API
JAX 中文文档(五)(1)
JAX 中文文档(五)
10 0
|
9天前
|
存储 Python
JAX 中文文档(十)(3)
JAX 中文文档(十)
7 0
|
9天前
|
机器学习/深度学习 API 索引
JAX 中文文档(二)(2)
JAX 中文文档(二)
11 0
|
9天前
|
存储 安全 API
JAX 中文文档(十)(2)
JAX 中文文档(十)
12 0