JAX 中文文档(一)(3)https://developer.aliyun.com/article/1559832
🔪 动态形状
在像jax.jit
、jax.vmap
、jax.grad
等变换中使用的 JAX 代码要求所有输出数组和中间数组具有静态形状:即形状不能依赖于其他数组中的值。
例如,如果您正在实现自己的版本jnp.nansum
,您可能会从以下内容开始:
def nansum(x): mask = ~jnp.isnan(x) # boolean mask selecting non-nan values x_without_nans = x[mask] return x_without_nans.sum()
在 JIT 和其他转换之外,这可以正常工作:
x = jnp.array([1, 2, jnp.nan, 3, 4]) print(nansum(x))
10.0
如果尝试将jax.jit
或另一个转换应用于此函数,则会报错:
jax.jit(nansum)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5]) See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
问题在于x_without_nans
的大小取决于x
中的值,这另一种方式说它的大小是动态的。通常在 JAX 中,可以通过其他方式绕过对动态大小数组的需求。例如,在这里可以使用jnp.where
的三参数形式,将 NaN 值替换为零,从而计算出相同的结果,同时避免动态形状:
@jax.jit def nansum_2(x): mask = ~jnp.isnan(x) # boolean mask selecting non-nan values return jnp.where(mask, x, 0).sum() print(nansum_2(x))
10.0
在其他情况下,类似的技巧可以发挥作用,其中动态形状数组出现。
🔪 NaNs
调试 NaNs
如果要追踪你的函数或梯度中出现 NaN 的位置,可以通过以下方式打开 NaN 检查器:
- 设置
JAX_DEBUG_NANS=True
环境变量; - 在你的主文件顶部添加
jax.config.update("jax_debug_nans", True)
; - 在你的主文件中添加
jax.config.parse_flags_with_absl()
,然后使用命令行标志设置选项,如--jax_debug_nans=True
;
这将导致 NaN 产生时立即终止计算。打开此选项会在由 XLA 产生的每个浮点类型值上添加 NaN 检查。这意味着对于不在@jit
下的每个基元操作,值将被拉回主机并作为 ndarrays 进行检查。对于在@jit
下的代码,将检查每个@jit
函数的输出,如果存在 NaN,则将以逐个操作的去优化模式重新运行函数,有效地一次移除一个@jit
级别。
可能会出现棘手的情况,比如只在@jit
下出现的 NaN,但在去优化模式下却不会产生。在这种情况下,你会看到警告消息打印出来,但你的代码将继续执行。
如果在梯度评估的反向传递中产生 NaNs,当在堆栈跟踪中引发异常时,您将位于 backward_pass 函数中,这本质上是一个简单的 jaxpr 解释器,以反向遍历原始操作序列。在下面的示例中,我们使用命令行env JAX_DEBUG_NANS=True ipython
启动了一个 ipython repl,然后运行了以下命令:
In [1]: import jax.numpy as jnp In [2]: jnp.divide(0., 0.) --------------------------------------------------------------------------- FloatingPointError Traceback (most recent call last) <ipython-input-2-f2e2c413b437> in <module>() ----> 1 jnp.divide(0., 0.) .../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2) 343 return floor_divide(x1, x2) 344 else: --> 345 return true_divide(x1, x2) 346 347 .../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2) 332 x1, x2 = _promote_shapes(x1, x2) 333 return lax.div(lax.convert_element_type(x1, result_dtype), --> 334 lax.convert_element_type(x2, result_dtype)) 335 336 .../jax/jax/lax.pyc in div(x, y) 244 def div(x, y): 245 r"""Elementwise division: :math:`x \over y`.""" --> 246 return div_p.bind(x, y) 247 248 def rem(x, y): ... stack trace ... .../jax/jax/interpreters/xla.pyc in handle_result(device_buffer) 103 py_val = device_buffer.to_py() 104 if np.any(np.isnan(py_val)): --> 105 raise FloatingPointError("invalid value") 106 else: 107 return Array(device_buffer, *result_shape) FloatingPointError: invalid value
捕获到生成的 NaN。通过运行%debug
,我们可以获得后期调试器。正如下面的示例所示,这也适用于在@jit
下的函数。
In [4]: from jax import jit In [5]: @jit ...: def f(x, y): ...: a = x * y ...: b = (x + y) / (x - y) ...: c = a + 2 ...: return a + b * c ...: In [6]: x = jnp.array([2., 0.]) In [7]: y = jnp.array([3., 0.]) In [8]: f(x, y) Invalid value encountered in the output of a jit function. Calling the de-optimized version. --------------------------------------------------------------------------- FloatingPointError Traceback (most recent call last) <ipython-input-8-811b7ddb3300> in <module>() ----> 1 f(x, y) ... stack trace ... <ipython-input-5-619b39acbaac> in f(x, y) 2 def f(x, y): 3 a = x * y ----> 4 b = (x + y) / (x - y) 5 c = a + 2 6 return a + b * c .../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2) 343 return floor_divide(x1, x2) 344 else: --> 345 return true_divide(x1, x2) 346 347 .../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2) 332 x1, x2 = _promote_shapes(x1, x2) 333 return lax.div(lax.convert_element_type(x1, result_dtype), --> 334 lax.convert_element_type(x2, result_dtype)) 335 336 .../jax/jax/lax.pyc in div(x, y) 244 def div(x, y): 245 r"""Elementwise division: :math:`x \over y`.""" --> 246 return div_p.bind(x, y) 247 248 def rem(x, y): ... stack trace ...
当此代码在 @jit
函数的输出中看到 NaN 时,它调用去优化的代码,因此我们仍然可以获得清晰的堆栈跟踪。我们可以使用 %debug
运行事后调试器来检查所有值,以找出错误。
⚠️ 如果您不是在调试,就不应该开启 NaN 检查器,因为它可能会导致大量设备主机往返和性能回归!
⚠️ NaN 检查器在 pmap
中不起作用。要调试 pmap
代码中的 NaN,可以尝试用 vmap
替换 pmap
。
🔪 双精度(64 位)
目前,默认情况下,JAX 强制使用单精度数字,以减少 Numpy API 将操作数过度提升为 double
的倾向。这是许多机器学习应用程序的期望行为,但可能会让您感到意外!
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64) x.dtype
/tmp/ipykernel_1227/1258726447.py:1: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> is not available, and will be truncated to dtype float32\. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more. x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
dtype('float32')
要使用双精度数,您需要在启动时设置 jax_enable_x64
配置变量**。
有几种方法可以做到这一点:
- 您可以通过设置环境变量
JAX_ENABLE_X64=True
来启用 64 位模式。 - 您可以在启动时手动设置
jax_enable_x64
配置标志:
# again, this only works on startup! import jax jax.config.update("jax_enable_x64", True)
- 您可以使用
absl.app.run(main)
解析命令行标志
import jax jax.config.config_with_absl()
- 如果您希望 JAX 为您运行 absl 解析,即您不想执行
absl.app.run(main)
,您可以改用
import jax if __name__ == '__main__': # calls jax.config.config_with_absl() *and* runs absl parsing jax.config.parse_flags_with_absl()
请注意,#2-#4 适用于任何 JAX 的配置选项。
然后,我们可以确认已启用 x64
模式:
import jax.numpy as jnp from jax import random x = random.uniform(random.key(0), (1000,), dtype=jnp.float64) x.dtype # --> dtype('float64')
/tmp/ipykernel_1227/2819792939.py:3: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> is not available, and will be truncated to dtype float32\. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more. x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
dtype('float32')
注意事项
⚠️ XLA 不支持所有后端的 64 位卷积!
🔪 NumPy 中的各种分歧
虽然 jax.numpy
尽力复制 numpy API 的行为,但确实存在一些边界情况,其行为有所不同。许多这样的情况在前面的部分中有详细讨论;这里我们列出了几个已知的其他 API 分歧处。
- 对于二进制操作,JAX 的类型提升规则与 NumPy 略有不同。有关更多详细信息,请参阅类型提升语义。
- 在执行不安全类型转换(即目标 dtype 不能表示输入值的转换)时,JAX 的行为可能依赖于后端,并且通常可能与 NumPy 的行为不同。NumPy 允许通过
casting
参数(参见np.ndarray.astype
)控制这些情况下的结果;JAX 不提供任何此类配置,而是直接继承XLA:ConvertElementType的行为。
这是一个示例,显示了在 NumPy 和 JAX 之间存在不同结果的不安全转换:
>>> np.arange(254.0, 258.0).astype('uint8') array([254, 255, 0, 1], dtype=uint8) >>> jnp.arange(254.0, 258.0).astype('uint8') Array([254, 255, 255, 255], dtype=uint8)
- 这种不匹配通常在将浮点值转换为整数类型或反之时出现极端情况。
结束。
如果这里没有涉及到您曾经因之而哭泣和咬牙切齿的问题,请告知我们,我们将扩展这些介绍性建议!
JAX 常见问题解答(FAQ)
我们在这里收集了一些经常被问到的问题的答案。欢迎贡献!
jit
改变了我的函数行为
如果你有一个在使用jax.jit()
后改变行为的 Python 函数,也许你的函数使用了全局状态或具有副作用。在下面的代码中,impure_func
使用了全局变量y
并由于print
而具有副作用:
y = 0 # @jit # Different behavior with jit def impure_func(x): print("Inside:", y) return x + y for y in range(3): print("Result:", impure_func(y))
没有jit
时的输出是:
Inside: 0 Result: 0 Inside: 1 Result: 2 Inside: 2 Result: 4
并且使用jit
时:
Inside: 0 Result: 0 Result: 1 Result: 2
对于jax.jit()
,函数在 Python 解释器中执行一次,此时发生Inside
打印,并观察到y
的第一个值。然后,函数被编译并缓存,以不同的x
值多次执行,但y
的第一个值相同。
更多阅读:
JAX 中文文档(一)(5)https://developer.aliyun.com/article/1559836