JAX 中文文档(五)(2)https://developer.aliyun.com/article/1559809
要了解关于追踪器与常规值、具体值与抽象值的更多微妙之处,可以阅读有关不同类型的 JAX 值。
参数:
tracer (core.Tracer)
class jax.errors.TracerBoolConversionError(tracer)
当在期望布尔值的上下文中使用 JAX 中的追踪值时会出现此错误(详见不同类型的 JAX 值,了解追踪器的更多信息)。
布尔转换可以是显式的(例如bool(x)
)或隐式的,通过控制流的使用(例如if x > 0
或while x
)、使用 Python 布尔运算符(例如z = x and y
、z = x or y
、z = not x
)或使用它们的函数(例如z = max(x, y)
、z = min(x, y)
等)。
在某些情况下,通过将跟踪值标记为静态,可以轻松解决此问题;在其他情况下,这可能表明您的程序正在执行 JAX JIT 编译模型不直接支持的操作。
示例:
在控制流中使用跟踪值
一个经常出现这种情况的案例是,当跟踪值用于 Python 控制流时。例如:
>>> from jax import jit >>> import jax.numpy as jnp >>> @jit ... def func(x, y): ... return x if x.sum() < y.sum() else y >>> func(jnp.ones(4), jnp.zeros(4)) Traceback (most recent call last): ... TracerBoolConversionError: Attempted boolean conversion of JAX Tracer [...]
我们可以将输入的x
和y
都标记为静态,但这样做将破坏在这里使用jax.jit()
的目的。另一个选择是将 if 语句重新表达为三项jax.numpy.where()
:
>>> @jit ... def func(x, y): ... return jnp.where(x.sum() < y.sum(), x, y) >>> func(jnp.ones(4), jnp.zeros(4)) Array([0., 0., 0., 0.], dtype=float32)
对于包括循环在内的更复杂的控制流,请参阅控制流运算符。
跟踪值在控制流中的使用
另一个常见的错误原因是,如果您无意中在布尔标志上进行跟踪。例如:
>>> @jit ... def func(x, normalize=True): ... if normalize: ... return x / x.sum() ... return x >>> func(jnp.arange(5), True) Traceback (most recent call last): ... TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...
在这里,因为标志normalize
被跟踪,所以不能在 Python 控制流中使用它。在这种情况下,最好的解决方案可能是将此值标记为静态:
>>> from functools import partial >>> @partial(jit, static_argnames=['normalize']) ... def func(x, normalize=True): ... if normalize: ... return x / x.sum() ... return x >>> func(jnp.arange(5), True) Array([0\. , 0.1, 0.2, 0.3, 0.4], dtype=float32)
有关static_argnums
的更多信息,请参阅jax.jit()
的文档。
使用非 JAX 感知的函数
另一个常见的错误原因是在 JAX 代码中使用非 JAX 感知的函数。例如:
>>> @jit ... def func(x): ... return min(x, 0)
>>> func(2) Traceback (most recent call last): ... TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...
在这种情况下,错误是因为 Python 的内置min
函数与 JAX 变换不兼容。可以通过将其替换为jnp.minimum
来修复这个问题:
>>> @jit ... def func(x): ... return jnp.minimum(x, 0)
>>> print(func(2)) 0
要更深入了解关于跟踪器与常规值、具体值与抽象值之间的微妙差别,您可能需要阅读关于不同类型 JAX 值的文档。
参数:
tracer(core.Tracer)
class jax.errors.TracerIntegerConversionError(tracer)
如果在期望 Python 整数的上下文中使用 JAX Tracer 对象,则可能会出现此错误(有关 Tracer 是什么的更多信息,请参阅关于不同类型 JAX 值的内容)。它通常发生在几种情况下。
将跟踪器放在整数位置
如果您试图将跟踪值传递给需要静态整数参数的函数,则可能会出现此错误;例如:
>>> from jax import jit >>> import numpy as np >>> @jit ... def func(x, axis): ... return np.split(x, 2, axis) >>> func(np.arange(4), 0) Traceback (most recent call last): ... TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[0]
当出现这种情况时,解决方案通常是将有问题的参数标记为静态:
>>> from functools import partial >>> @partial(jit, static_argnums=1) ... def func(x, axis): ... return np.split(x, 2, axis) >>> func(np.arange(10), 0) [Array([0, 1, 2, 3, 4], dtype=int32), Array([5, 6, 7, 8, 9], dtype=int32)]
另一种方法是将转换应用于封装要保护参数的闭包,可以手动执行如下或使用functools.partial()
:
>>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4)) [Array([0, 1], dtype=int32), Array([2, 3], dtype=int32)]
请注意,每次调用都会创建一个新的闭包,这会破坏编译缓存机制,这也是为什么首选static_argnums
的原因。
使用跟踪器索引列表
如果您尝试使用跟踪的量索引 Python 列表,则可能会出现此错误。例如:
>>> import jax.numpy as jnp >>> from jax import jit >>> L = [1, 2, 3] >>> @jit ... def func(i): ... return L[i] >>> func(0) Traceback (most recent call last): ... TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[0]
根据上下文,通常可以通过将列表转换为 JAX 数组来解决此问题:
>>> @jit ... def func(i): ... return jnp.array(L)[i] >>> func(0) Array(1, dtype=int32)
或者通过将索引声明为静态参数来声明:
>>> from functools import partial >>> @partial(jit, static_argnums=0) ... def func(i): ... return L[i] >>> func(0) Array(1, dtype=int32, weak_type=True)
要更深入理解跟踪器与常规值以及具体与抽象值之间的微妙差别,您可以阅读有关不同类型 JAX 值的文档。
参数:
tracer(core.Tracer)
class jax.errors.UnexpectedTracerError(msg)
当您使用从函数中泄漏出来的 JAX 值时,会出现此错误。泄漏值是什么意思?如果您对函数f
应用 JAX 转换,并在f
外某个作用域存储了一个中间值的引用,那么该值被视为已泄漏。泄漏值是副作用。(阅读更多关于避免副作用的内容,请参阅Pure Functions)
JAX 在你稍后在另一个操作中使用泄露的值时检测到泄漏,此时会引发UnexpectedTracerError
。要修复此问题,请避免副作用:如果一个函数计算了外部作用域需要的值,则需要明确从转换后的函数中返回该值。
具体来说,Tracer
是 JAX 在转换期间函数中间值的内部表示,例如在jit()
、pmap()
、vmap()
等内部。在转换之外遇到Tracer
表示泄漏。
泄漏值的生命周期
请考虑以下转换函数的示例,它将一个值泄漏到外部作用域:
>>> from jax import jit >>> import jax.numpy as jnp >>> outs = [] >>> @jit # 1 ... def side_effecting(x): ... y = x + 1 # 3 ... outs.append(y) # 4 >>> x = 1 >>> side_effecting(x) # 2 >>> outs[0] + 1 # 5 Traceback (most recent call last): ... UnexpectedTracerError: Encountered an unexpected tracer.
在此示例中,我们从内部转换作用域泄漏了一个跟踪值到外部作用域。当使用泄漏值而不是泄漏值时,会出现UnexpectedTracerError
。
此示例还展示了泄漏值的生命周期:
- 函数被转换了(在本例中,通过
jit()
)。- 调用了转换后的函数(启动函数的抽象跟踪,并将
x
转换为Tracer
)。- 中间值
y
被创建,稍后将被泄漏(跟踪函数的中间值也是Tracer
)。- 该值已泄漏(通过外部作用域的一个侧通道将其追加到列表中逃逸函数)
- 使用了泄漏的值,并引发了 UnexpectedTracerError。
UnexpectedTracerError 消息试图通过包含有关每个阶段信息的方法来指出代码中的这些位置。依次:
- 转换后函数的名称(
side_effecting
)以及触发跟踪的转换名称jit()
)。- 泄漏的 Tracer 创建时的重构堆栈跟踪,包括调用转换后函数的位置。(
When the Tracer was created, the final 5 stack frames were...
)。- 从重构的堆栈跟踪中,创建泄漏 Tracer 的代码行。
- 错误消息中不包括泄漏位置,因为难以确定!JAX 只能告诉你泄漏值的外观(其形状和创建位置)以及泄漏的边界(变换的名称和转换后函数的名称)。
- 当前错误的堆栈跟踪指向值的使用位置。
可以通过将值从转换函数返回来修复错误:
>>> from jax import jit >>> import jax.numpy as jnp >>> outs = [] >>> @jit ... def not_side_effecting(x): ... y = x+1 ... return y >>> x = 1 >>> y = not_side_effecting(x) >>> outs.append(y) >>> outs[0] + 1 # all good! no longer a leaked value. Array(3, dtype=int32, weak_type=True)
泄漏检查器
如上述第 2 和第 3 点所讨论的那样,JAX 显示了一个重建的堆栈跟踪,指出了泄露值的创建位置。这是因为 JAX 仅在使用泄露值时才会引发错误,而不是在值泄漏时。这不是引发此错误的最有用的地方,因为您需要知道泄露跟踪器的位置来修复错误。
为了更容易跟踪此位置,您可以使用泄漏检查器。当启用泄漏检查器时,一旦泄露了Tracer
,就会引发错误。(更确切地说,在从中泄漏Tracer
的转换函数返回时会引发错误)
要启用泄漏检查器,可以使用JAX_CHECK_TRACER_LEAKS
环境变量或with jax.checking_leaks()
上下文管理器。
注意
请注意,此工具属于实验性质,可能会报告错误的情况。它通过禁用某些 JAX 缓存工作,因此会对性能产生负面影响,应仅在调试时使用。
示例用法:
>>> from jax import jit >>> import jax.numpy as jnp >>> outs = [] >>> @jit ... def side_effecting(x): ... y = x+1 ... outs.append(y) >>> x = 1 >>> with jax.checking_leaks(): ... y = side_effecting(x) Traceback (most recent call last): ... Exception: Leaked Trace
参数:
msg (str)
转移保护
JAX 可能在类型转换和输入分片期间在主机和设备之间传输数据。为了记录或阻止任何意外的转移,用户可以配置 JAX 转移保护。
JAX 转移保护区分两种类型的转移:
- 显式转移:
jax.device_put*()
和jax.device_get()
调用。 - 隐式转移:其他转移(例如打印
DeviceArray
)。
转移保护可以根据其保护级别采取行动:
"allow"
: 静默允许所有转移(默认)。"log"
: 记录并允许隐式转移。静默允许显式转移。"disallow"
: 禁止隐式转移。静默允许显式转移。"log_explicit"
: 记录并允许所有转移。"disallow_explicit"
: 禁止所有转移。
当禁止转移时,JAX 将引发 RuntimeError
。
转移保护使用标准的 JAX 配置系统:
- 一个
--jax_transfer_guard=GUARD_LEVEL
命令行标志和jax.config.update("jax_transfer_guard", GUARD_LEVEL)
将设置全局选项。 - 一个
with jax.transfer_guard(GUARD_LEVEL): ...
上下文管理器将在上下文管理器的作用域内设置线程局部选项。
注意,类似于其他 JAX 配置选项,新生成的线程将使用全局选项,而不是生成线程所在作用域的任何活动线程局部选项。
转移保护还可以根据转移方向更为选择性地应用。标志和上下文管理器名称以相应的转移方向作为后缀(例如 --jax_transfer_guard_host_to_device
和 jax.config.transfer_guard_host_to_device
):
"host_to_device"
: 将 Python 值或 NumPy 数组转换为 JAX 设备上的缓冲区。"device_to_device"
: 将 JAX 设备缓冲区复制到另一个设备。"device_to_host"
: 从 JAX 设备缓冲区获取数据。
获取 CPU 设备上的缓冲区始终允许,无论转移保护级别如何。
下面展示了使用转移保护的示例。
>>> jax.config.update("jax_transfer_guard", "allow") # This is default. >>> >>> x = jnp.array(1) >>> y = jnp.array(2) >>> z = jnp.array(3) >>> >>> print("x", x) # All transfers are allowed. x 1 >>> with jax.transfer_guard("disallow"): ... print("x", x) # x has already been fetched into the host. ... print("y", jax.device_get(y)) # Explicit transfers are allowed. ... try: ... print("z", z) # Implicit transfers are disallowed. ... assert False, "This line is expected to be unreachable." ... except: ... print("z could not be fetched") x 1 y 2 z could not be fetched
Pallas:一个 JAX 内核语言
Pallas 是 JAX 的扩展,允许为 GPU 和 TPU 编写自定义内核。本节包含使用 Pallas 的教程、指南和示例。
指南
- Pallas 设计
- 介绍
- Pallas:为内核扩展 JAX
- Pallas 快速入门
- 在 Pallas 中的 Hello world
- Pallas 编程模型
平台特性
- Pallas TPU
- 使用 Pallas 编写 TPU 内核
- 流水线和
BlockSpec
JAX 中文文档(五)(4)https://developer.aliyun.com/article/1559812