JAX 中文文档(一)(4)

简介: JAX 中文文档(一)

JAX 中文文档(一)(3)https://developer.aliyun.com/article/1559832


🔪 动态形状

在像jax.jitjax.vmapjax.grad等变换中使用的 JAX 代码要求所有输出数组和中间数组具有静态形状:即形状不能依赖于其他数组中的值。

例如,如果您正在实现自己的版本jnp.nansum,您可能会从以下内容开始:

def nansum(x):
  mask = ~jnp.isnan(x)  # boolean mask selecting non-nan values
  x_without_nans = x[mask]
  return x_without_nans.sum() 

在 JIT 和其他转换之外,这可以正常工作:

x = jnp.array([1, 2, jnp.nan, 3, 4])
print(nansum(x)) 
10.0 

如果尝试将jax.jit或另一个转换应用于此函数,则会报错:

jax.jit(nansum)(x) 
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5])
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError 

问题在于x_without_nans的大小取决于x中的值,这另一种方式说它的大小是动态的。通常在 JAX 中,可以通过其他方式绕过对动态大小数组的需求。例如,在这里可以使用jnp.where的三参数形式,将 NaN 值替换为零,从而计算出相同的结果,同时避免动态形状:

@jax.jit
def nansum_2(x):
  mask = ~jnp.isnan(x)  # boolean mask selecting non-nan values
  return jnp.where(mask, x, 0).sum()
print(nansum_2(x)) 
10.0 

在其他情况下,类似的技巧可以发挥作用,其中动态形状数组出现。

🔪 NaNs

调试 NaNs

如果要追踪你的函数或梯度中出现 NaN 的位置,可以通过以下方式打开 NaN 检查器:

  • 设置JAX_DEBUG_NANS=True环境变量;
  • 在你的主文件顶部添加jax.config.update("jax_debug_nans", True)
  • 在你的主文件中添加jax.config.parse_flags_with_absl(),然后使用命令行标志设置选项,如--jax_debug_nans=True

这将导致 NaN 产生时立即终止计算。打开此选项会在由 XLA 产生的每个浮点类型值上添加 NaN 检查。这意味着对于不在@jit下的每个基元操作,值将被拉回主机并作为 ndarrays 进行检查。对于在@jit下的代码,将检查每个@jit函数的输出,如果存在 NaN,则将以逐个操作的去优化模式重新运行函数,有效地一次移除一个@jit级别。

可能会出现棘手的情况,比如只在@jit下出现的 NaN,但在去优化模式下却不会产生。在这种情况下,你会看到警告消息打印出来,但你的代码将继续执行。

如果在梯度评估的反向传递中产生 NaNs,当在堆栈跟踪中引发异常时,您将位于 backward_pass 函数中,这本质上是一个简单的 jaxpr 解释器,以反向遍历原始操作序列。在下面的示例中,我们使用命令行env JAX_DEBUG_NANS=True ipython启动了一个 ipython repl,然后运行了以下命令:

In [1]: import jax.numpy as jnp
In [2]: jnp.divide(0., 0.)
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-2-f2e2c413b437> in <module>()
----> 1 jnp.divide(0., 0.)
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336
.../jax/jax/lax.pyc in div(x, y)
    244 def div(x, y):
    245   r"""Elementwise division: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):
... stack trace ...
.../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)
    103         py_val = device_buffer.to_py()
    104         if np.any(np.isnan(py_val)):
--> 105           raise FloatingPointError("invalid value")
    106         else:
    107           return Array(device_buffer, *result_shape)
FloatingPointError: invalid value 

捕获到生成的 NaN。通过运行%debug,我们可以获得后期调试器。正如下面的示例所示,这也适用于在@jit下的函数。

In [4]: from jax import jit
In [5]: @jit
   ...: def f(x, y):
   ...:     a = x * y
   ...:     b = (x + y) / (x - y)
   ...:     c = a + 2
   ...:     return a + b * c
   ...:
In [6]: x = jnp.array([2., 0.])
In [7]: y = jnp.array([3., 0.])
In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)
 ... stack trace ...
<ipython-input-5-619b39acbaac> in f(x, y)
      2 def f(x, y):
      3     a = x * y
----> 4     b = (x + y) / (x - y)
      5     c = a + 2
      6     return a + b * c
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336
.../jax/jax/lax.pyc in div(x, y)
    244 def div(x, y):
    245   r"""Elementwise division: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):
 ... stack trace ... 

当此代码在 @jit 函数的输出中看到 NaN 时,它调用去优化的代码,因此我们仍然可以获得清晰的堆栈跟踪。我们可以使用 %debug 运行事后调试器来检查所有值,以找出错误。

⚠️ 如果您不是在调试,就不应该开启 NaN 检查器,因为它可能会导致大量设备主机往返和性能回归!

⚠️ NaN 检查器在 pmap 中不起作用。要调试 pmap 代码中的 NaN,可以尝试用 vmap 替换 pmap

🔪 双精度(64 位)

目前,默认情况下,JAX 强制使用单精度数字,以减少 Numpy API 将操作数过度提升为 double 的倾向。这是许多机器学习应用程序的期望行为,但可能会让您感到意外!

x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype 
/tmp/ipykernel_1227/1258726447.py:1: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'>  is not available, and will be truncated to dtype float32\. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  x = random.uniform(random.key(0), (1000,), dtype=jnp.float64) 
dtype('float32') 

要使用双精度数,您需要在启动时设置 jax_enable_x64 配置变量**。

有几种方法可以做到这一点:

  1. 您可以通过设置环境变量 JAX_ENABLE_X64=True 来启用 64 位模式。
  2. 您可以在启动时手动设置 jax_enable_x64 配置标志:
# again, this only works on startup!
import jax
jax.config.update("jax_enable_x64", True) 
  1. 您可以使用 absl.app.run(main) 解析命令行标志
import jax
jax.config.config_with_absl() 
  1. 如果您希望 JAX 为您运行 absl 解析,即您不想执行 absl.app.run(main),您可以改用
import jax
if __name__ == '__main__':
  # calls jax.config.config_with_absl() *and* runs absl parsing
  jax.config.parse_flags_with_absl() 

请注意,#2-#4 适用于任何 JAX 的配置选项。

然后,我们可以确认已启用 x64 模式:

import jax.numpy as jnp
from jax import random
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64') 
/tmp/ipykernel_1227/2819792939.py:3: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'>  is not available, and will be truncated to dtype float32\. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  x = random.uniform(random.key(0), (1000,), dtype=jnp.float64) 
dtype('float32') 

注意事项

⚠️ XLA 不支持所有后端的 64 位卷积!

🔪 NumPy 中的各种分歧

虽然 jax.numpy 尽力复制 numpy API 的行为,但确实存在一些边界情况,其行为有所不同。许多这样的情况在前面的部分中有详细讨论;这里我们列出了几个已知的其他 API 分歧处。

  • 对于二进制操作,JAX 的类型提升规则与 NumPy 略有不同。有关更多详细信息,请参阅类型提升语义
  • 在执行不安全类型转换(即目标 dtype 不能表示输入值的转换)时,JAX 的行为可能依赖于后端,并且通常可能与 NumPy 的行为不同。NumPy 允许通过 casting 参数(参见np.ndarray.astype)控制这些情况下的结果;JAX 不提供任何此类配置,而是直接继承XLA:ConvertElementType的行为。
    这是一个示例,显示了在 NumPy 和 JAX 之间存在不同结果的不安全转换:
>>> np.arange(254.0, 258.0).astype('uint8')
array([254, 255,   0,   1], dtype=uint8)
>>> jnp.arange(254.0, 258.0).astype('uint8')
Array([254, 255, 255, 255], dtype=uint8) 
  • 这种不匹配通常在将浮点值转换为整数类型或反之时出现极端情况。

结束。

如果这里没有涉及到您曾经因之而哭泣和咬牙切齿的问题,请告知我们,我们将扩展这些介绍性建议

JAX 常见问题解答(FAQ)

原文:jax.readthedocs.io/en/latest/faq.html

我们在这里收集了一些经常被问到的问题的答案。欢迎贡献!

jit改变了我的函数行为

如果你有一个在使用jax.jit()后改变行为的 Python 函数,也许你的函数使用了全局状态或具有副作用。在下面的代码中,impure_func使用了全局变量y并由于print而具有副作用:

y = 0
# @jit   # Different behavior with jit
def impure_func(x):
  print("Inside:", y)
  return x + y
for y in range(3):
  print("Result:", impure_func(y)) 

没有jit时的输出是:

Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4 

并且使用jit时:

Inside: 0
Result: 0
Result: 1
Result: 2 

对于jax.jit(),函数在 Python 解释器中执行一次,此时发生Inside打印,并观察到y的第一个值。然后,函数被编译并缓存,以不同的x值多次执行,但y的第一个值相同。

更多阅读:


JAX 中文文档(一)(5)https://developer.aliyun.com/article/1559836

相关文章
|
2天前
|
机器学习/深度学习 存储 移动开发
JAX 中文文档(八)(1)
JAX 中文文档(八)
9 1
|
2天前
|
编译器 API 异构计算
JAX 中文文档(一)(2)
JAX 中文文档(一)
8 0
|
2天前
|
Serverless C++ Python
JAX 中文文档(九)(5)
JAX 中文文档(九)
10 0
|
2天前
|
机器学习/深度学习 算法 编译器
JAX 中文文档(二)(3)
JAX 中文文档(二)
8 0
|
2天前
JAX 中文文档(九)(3)
JAX 中文文档(九)
7 0
|
2天前
|
机器学习/深度学习 缓存 编译器
JAX 中文文档(二)(1)
JAX 中文文档(二)
8 0
|
2天前
|
测试技术 API Python
JAX 中文文档(八)(4)
JAX 中文文档(八)
9 0
|
2天前
|
机器学习/深度学习 索引 Python
JAX 中文文档(四)(1)
JAX 中文文档(四)
7 0
|
2天前
|
存储 机器学习/深度学习 TensorFlow
JAX 中文文档(七)(5)
JAX 中文文档(七)
7 0
|
2天前
|
缓存 PyTorch API
JAX 中文文档(一)(3)
JAX 中文文档(一)
8 0