JAX 中文文档(十一)(2)https://developer.aliyun.com/article/1559780
当全阶段打开时可能会出现哪些问题?
当在jit
或pmap
的动态上下文中,从 Python 到 XLA 分阶段所有jax.numpy
操作的结果,一些之前正常工作的代码可能会开始引发大声的错误。正如下文所解释的那样,这些行为在全阶段之前已经存在 bug,但全阶段将它们变成了严格的错误。
使用jax.numpy
进行形状计算
示例
from jax import jit import jax.numpy as jnp @jit def ex1(x): size = jnp.prod(jnp.array(x.shape)) return x.reshape((size,)) ex1(jnp.ones((3, 4)))
错误消息
[... full traceback ...] File "/home/mattjj/packages/jax/jax/core.py", line 862, in raise_concretization_error raise ConcretizationTypeError(msg) jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected. The error arose in jax.numpy.reshape. While tracing the function ex1 at ex1.py:4, this value became a tracer due to JAX operations on these lines: operation c:int32[] = reduce_prod[ axes=(0,) ] b:int32[2] from line ex1.py:6 (ex1) You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions. See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information. Encountered tracer value: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
解释
在全面化下,我们不能像上面使用jnp.prod
一样在 jit 函数的动态上下文中使用jax.numpy
进行形状计算,因为这些操作将被分阶段为在执行时计算的值,但我们需要它们是编译时常量(因此是跟踪时常量)。
在全面化之前,这段代码不会引发错误,但这是一个常见的性能 bug:jnp.prod
计算将在跟踪时间在设备上执行,意味着额外的编译、传输、同步、分配和潜在的内存碎片化。
解决方案
解决方法很简单,就是像这样的形状计算使用原始的numpy
。这不仅避免了错误,还将计算保持在主机上(并且开销更低)。
在代码中,这个问题很常见,我们努力使错误消息尤其好。除了堆栈跟踪显示抽象跟踪器值导致问题的位置(完整堆栈跟踪中的jnp.reshape
行,在 omni.py:10),我们还解释了这个值首先变成跟踪器的原因,指向导致它成为抽象跟踪器的上游原始操作(来自jnp.prod
中的reduce_prod
,在 omni.py:9),以及跟踪器属于哪个带jit
装饰的函数(在 omni.py:6 中的ex1
)。
副作用
示例
from jax import jit from jax import random key = random.PRNGKey(0) def init(): global key key, subkey = random.split(key) return random.normal(subkey, ()) print(init()) # -1.2515389 print(init()) # -0.58665067 init = jit(init) print(init()) # 0.48648298 print(init()) # 0.48648298 !!
最后一个调用具有重复的随机性,但没有硬错误,因为我们没有重新执行 Python。但是如果我们查看key
,我们会看到一个逃逸的跟踪器开启全面化时:
print(key) # Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>
在全面化之前,random.split
调用不会被分阶段处理,因此我们不会得到逃逸的跟踪器。由于重复使用相同的 PRNG 密钥,代码仍然存在 bug,即编译函数无法复制原始函数的语义(因为有副作用)。
在开启全面化时,如果再次触及key
,将会得到一个逃逸的跟踪器错误:
random.normal(key, ())
错误消息
[... full stack trace …] File "/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 836, in _assert_live raise core.escaped_tracer_error(msg) jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function. The functions being transformed should not save traced values to global state. Detail: tracer created on line example.py:8 (init).
解释
我们发现的次大类全面化问题与副作用代码有关。这些代码通过转换有副作用的函数已经使 JAX 的保证失效,但由于预全面化的“跟踪时间常数折叠”行为,一些有副作用的函数仍然可能表现正确。全面化能更多地捕捉这些错误。
解决方案
解决方法是识别依赖副作用的 JAX 转换函数,并重新编写它们以避免有副作用。
基于 XLA 优化的小数值差异
因为在全面化下,更多的计算被分阶段到 XLA,而不是在跟踪时间执行,这可能导致浮点运算的重新排序。结果是,我们看到数值行为以一种导致测试在开启全面化时失败的方式改变,因为它们对于过紧容差的测试失败。
依赖于 JAX 内部 API 的变化
Omnistaging 涉及对 JAX 核心代码进行了一些重大修改,包括删除或更改内部函数。任何依赖这些内部 JAX API 的代码,在 omnistaging 打开时都可能会出现问题,可能是构建错误(来自 pytype)或运行时错误。
触发 XLA 编译时错误
由于 omnistaging 涉及将更多代码分阶段传递给 XLA,我们发现它可能会在某些后端触发现有的 XLA 编译时错误。对于这些问题,最好的做法是报告它们,以便我们与 XLA 团队合作进行修复。
JEP 9263:类型化密钥和可插拔的 RNG
Jake VanderPlas, Roy Frostig
August 2023
概述
未来,在 JAX 中,RNG 密钥将更加类型安全和可定制。 不再通过长度为 2 的uint32
数组表示单个 PRNG 密钥,而是通过一个标量数组表示,该数组具有满足jnp.issubdtype(key.dtype, jax.dtypes.prng_key)
的特殊 RNG dtype。
目前,可以使用jax.random.PRNGKey()
仍然创建旧样式的 RNG 密钥:
>>> key = jax.random.PRNGKey(0) >>> key Array([0, 0], dtype=uint32) >>> key.shape (2,) >>> key.dtype dtype('uint32')
从现在开始,可以使用jax.random.key()
创建新样式的 RNG 密钥:
>>> key = jax.random.key(0) >>> key Array((), dtype=key<fry>) overlaying: [0 0] >>> key.shape () >>> key.dtype key<fry>
这个(标量形状的)数组的行为与任何其他 JAX 数组相同,只是其元素类型是一个密钥(及其关联的元数据)。 我们也可以制作非标量密钥数组,例如通过将jax.vmap()
应用于jax.random.key()
:
>>> key_arr = jax.vmap(jax.random.key)(jnp.arange(4)) >>> key_arr Array((4,), dtype=key<fry>) overlaying: [[0 0] [0 1] [0 2] [0 3]] >>> key_arr.shape (4,)
除了切换到新的构造函数外,大多数与 PRNG 相关的代码应该继续按预期工作。 您可以像以前一样继续使用jax.random
API 中的密钥;例如:
# split new_key, subkey = jax.random.split(key) # random number generation data = jax.random.uniform(key, shape=(5,))
然而,并非所有数值操作都适用于密钥数组。 它们现在故意引发错误:
>>> key = key + 1 Traceback (most recent call last): TypeError: add does not accept dtypes key<fry>, int32.
如果出于某种原因您需要恢复底层缓冲区(旧样式密钥),您可以使用jax.random.key_data()
来实现:
>>> jax.random.key_data(key) Array([0, 0], dtype=uint32)
对于旧样式密钥,key_data()
是一个身份操作。
对用户来说,这意味着什么?
对于 JAX 用户,这种变化现在不需要任何代码更改,但我们希望您会发现升级是值得的,并切换到使用类型化密钥。 要尝试这个功能,请将使用jax.random.PRNGKey()
替换为jax.random.key()
。 这可能会在您的代码中引入一些破坏性变化,属于以下几类之一:
- 如果您的代码对密钥执行不安全/不支持的操作(如索引、算术运算、转置等;请参阅下面的类型安全部分),这种变化将捕捉到它。 您可以更新您的代码以避免此类不支持的操作,或者使用
jax.random.key_data()
和jax.random.wrap_key_data()
以不安全的方式操作原始密钥缓冲区。 - 如果您的代码包含关于
key.shape
的显式逻辑,您可能需要更新此逻辑以考虑尾部密钥缓冲区维度不再是形状的显式部分。 - 如果您的代码包含关于
key.dtype
的显式逻辑,您需要将其升级为使用新的公共 API 来推理 RNG dtypes,例如dtypes.issubdtype(dtype, dtypes.prng_key)
。 - 如果您调用一个尚未处理类型化 PRNG 密钥的基于 JAX 的库,您现在可以使用
raw_key = jax.random.key_data(key)
来恢复原始缓冲区,但请务必保留一个 TODO 来在下游库支持类型化 RNG 密钥后移除此操作。
在未来的某个时候,我们计划废弃jax.random.PRNGKey()
并要求使用jax.random.key()
。
检测新样式的类型化密钥
要检查对象是否为新样式的类型化 PRNG 密钥,可以使用jax.dtypes.issubdtype
或jax.numpy.issubdtype
:
>>> typed_key = jax.random.key(0) >>> jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key) True >>> raw_key = jax.random.PRNGKey(0) >>> jax.dtypes.issubdtype(raw_key.dtype, jax.dtypes.prng_key) False
PRNG 密钥的类型注释
旧式和新式 PRNG 密钥的推荐类型注释是 jax.Array
。PRNG 密钥根据其dtype
与其他数组区分开来,目前无法在类型注释中指定 JAX 数组的 dtype。以前可以使用jax.random.KeyArray
或jax.random.PRNGKeyArray
作为类型注释,但在类型检查下始终被别名为Any
,因此jax.Array
具有更高的特异性。
注:在 JAX 版本 0.4.16 中,jax.random.KeyArray
和 jax.random.PRNGKeyArray
已弃用,并在 JAX 版本 0.4.24 中移除。
JAX 库作者注意事项
如果您维护基于 JAX 的库,您的用户也是 JAX 用户。请知道 JAX 将继续支持“原始”旧式密钥在jax.random
中,因此调用者可能期望它们在所有地方都被接受。如果您希望在您的库中要求新式类型化密钥,则可能希望使用以下方式进行检查以强制执行它们:
from jax import dtypes def ensure_typed_key_array(key: Array) -> Array: if dtypes.issubdtype(key.dtype, dtypes.prng_key): return key else: raise TypeError("New-style typed JAX PRNG keys required")
动机
此更改的两个主要动机因素是可定制性和安全性。
自定义 PRNG 实现
JAX 目前使用单一的全局配置 PRNG 算法。PRNG 密钥是无符号 32 位整数的向量,jax.random API 使用它们生成伪随机流。任何更高秩的 uint32 数组都被解释为具有这些密钥缓冲区的数组,其中尾部维度表示密钥。
这种设计的缺点在我们引入替代的伪随机数生成器(PRNG)实现时变得更加明显,这些实现必须通过设置全局或本地配置标志来选择。不同的 PRNG 实现具有不同大小的密钥缓冲区和生成随机比特的不同算法。通过全局标志确定此行为容易出错,特别是在整个进程中使用多个密钥实现时。
我们的新方法是将实现作为 PRNG 密钥类型的一部分,即密钥数组的元素类型。使用新的密钥 API,下面是在默认的 threefry2x32 实现(纯 Python 实现,并与 JAX 编译)和非默认的 rbg 实现(对应单个 XLA 随机比特生成操作)下生成伪随机值的示例:
>>> key = jax.random.key(0, impl='threefry2x32') # this is the default impl >>> key Array((), dtype=key<fry>) overlaying: [0 0] >>> jax.random.uniform(key, shape=(3,)) Array([0.9653214 , 0.31468165, 0.63302994], dtype=float32) >>> key = jax.random.key(0, impl='rbg') >>> key Array((), dtype=key<rbg>) overlaying: [0 0 0 0] >>> jax.random.uniform(key, shape=(3,)) Array([0.39904642, 0.8805201 , 0.73571277], dtype=float32)
安全的 PRNG 密钥使用
原则上,PRNG 密钥确实只支持少数几种操作,即密钥衍生(例如拆分)和随机数生成。只要正确拆分密钥并且每个密钥只使用一次,PRNG 就设计为生成独立的伪随机数。
在其他方式中操作或消耗密钥数据的代码通常表明是意外的错误,将密钥数组表示为原始 uint32 缓冲区已经允许沿着这些方向容易发生误用。以下是我们在实际使用中遇到的几个示例错误用法:
密钥缓冲区索引
访问底层整数缓冲区使得可以轻松尝试以非标准方式导出密钥,有时会带来意想不到的不良后果:
# Incorrect key = random.PRNGKey(999) new_key = random.PRNGKey(key[1]) # identical to the original key!
# Correct key = random.PRNGKey(999) key, new_key = random.split(key)
如果此关键是使用random.key(999)
创建的新型类型化关键,则索引到关键缓冲区将会出错。
关键算术
关键算术是从其他关键派生关键的一种类似险恶的方式。通过直接操作关键数据而避免jax.random.split()
或jax.random.fold_in()
来派生关键,会产生一批关键,这些关键——根据 PRNG 实现——可能会在批次内生成相关的随机数:
# Incorrect key = random.PRNGKey(0) batched_keys = key + jnp.arange(10, dtype=key.dtype)[:, None]
# Correct key = random.PRNGKey(0) batched_keys = random.split(key, 10)
使用random.key(0)
创建的新型类型化关键通过禁止对关键进行算术操作来解决这个问题。
意外转置关键缓冲区
使用“原始”旧式关键数组时,很容易意外交换批次(前导)维度和关键缓冲区(尾随)维度。再次可能导致产生相关伪随机性的关键。多年来我们见过的一个模式归结如下:
# Incorrect keys = random.split(random.PRNGKey(0)) data = jax.vmap(random.uniform, in_axes=1)(keys)
# Correct keys = random.split(random.PRNGKey(0)) data = jax.vmap(random.uniform, in_axes=0)(keys)
这里的 bug 很微妙。通过在 in_axes=1
上映射,此代码通过将批次中每个关键缓冲区的单个元素组合成新关键来生成新关键。生成的关键彼此不同,但实质上以非标准方式“派生”。再次强调,PRNG 并未设计或测试以从这样的关键批次生成独立的随机流。
使用random.key(0)
创建的新型类型化关键通过隐藏个体关键的缓冲区表示,而将关键视为关键数组的不透明元素来解决这个问题。关键数组没有尾随的“缓冲区”维度可以索引、转置或映射。
关键重用
不像像numpy.random
这样的基于状态的 PRNG API,JAX 的函数式 PRNG 在使用后不会隐式更新关键。
# Incorrect key = random.PRNGKey(0) x = random.uniform(key, (100,)) y = random.uniform(key, (100,)) # Identical values!
# Correct key = random.PRNGKey(0) key1, key2 = random.split(random.key(0)) x = random.uniform(key1, (100,)) y = random.uniform(key2, (100,))
我们正在积极开发工具来检测和防止意外的关键重用。这仍然是一个正在进行中的工作,但它依赖于类型化关键数组。现在升级到类型化关键使我们能够在构建这些安全功能时引入它们。
类型化 PRNG 关键的设计
类型化 PRNG 关键在 JAX 中实现为扩展 dtypes 的实例,其中新的 PRNG dtypes 是子 dtype。
扩展 dtypes
从用户角度来看,扩展 dtype dt 具有以下用户可见属性:
jax.dtypes.issubdtype(dt, jax.dtypes.extended)
返回True
:这是应该用于检测 dtype 是否为扩展 dtype 的公共 API。- 它具有类级属性
dt.type
,返回在numpy.generic
层次结构中的类型类。这类似于np.dtype('int32').type
返回numpy.int32
,这不是 dtype 而是标量类型,并且是numpy.generic
的子类。 - 与 numpy 标量类型不同,我们不允许实例化
dt.type
标量对象:这符合 JAX 将标量值表示为零维数组的决定。
从非公开实现的角度来看,扩展 dtype 具有以下属性:
- 它的类型是私有基类
jax._src.dtypes.ExtendedDtype
的子类,这是用于扩展数据类型的非公开基类。ExtendedDtype
的实例类似于np.dtype
的实例,例如np.dtype('int32')
。 - 它具有私有的
_rules
属性,允许数据类型定义在特定操作下的行为方式。例如,当dtype
是扩展数据类型时,jax.lax.full(shape, fill_value, dtype)
将委托给dtype._rules.full(shape, fill_value, dtype)
。
为什么要在一般情况下引入扩展数据类型,超出了伪随机数生成器的范围?我们在内部的其他地方重复使用同样的扩展数据类型机制。例如,jax._src.core.bint
对象是另一种扩展数据类型,用于动态形状的实验工作。在最近的 JAX 版本中,它满足上述属性(见jax/_src/core.py#L1789-L1802)。
PRNG 数据类型
PRNG 数据类型被定义为扩展数据类型的特例。具体来说,此更改引入了一个新的公共标量类型类jax.dtypes.prng_key
,其具有以下属性:
>>> jax.dtypes.issubdtype(jax.dtypes.prng_key, jax.dtypes.extended) True
PRNG 密钥数组然后具有以下属性的数据类型:
>>> key = jax.random.key(0) >>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.extended) True >>> jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key) True
除了一般情况下扩展数据类型的key.dtype._rules
,PRNG 数据类型定义了key.dtype._impl
,其中包含定义 PRNG 实现的元数据。当前,PRNGImpl
并不打算成为公共 API,但我们可能很快会重新审视这一点,以允许完全自定义的 PRNG 实现。
进展
以下是实施上述设计的关键拉取请求的非全面列表。主要的跟踪问题是#9263。
- 通过
PRNGImpl
实现可插拔 PRNG:#6899 - 实现
PRNGKeyArray
,不包括数据类型:#11952 - 向
PRNGKeyArray
添加一个“自定义元素”数据类型属性,具有_rules
属性:#12167 - 将“自定义元素类型”重命名为“不透明数据类型”:#12170
- 重构
bint
以使用不透明数据类型基础设施:#12707 - 添加
jax.random.key
以直接创建带类型的密钥:#16086 - 为
key
和PRNGKey
添加impl
参数:#16589 - 将“不透明数据类型”重命名为“扩展数据类型”,并定义
jax.dtypes.extended
:#16824 - 引入
jax.dtypes.prng_key
并统一 PRNG 数据类型和扩展数据类型:#16781 - 添加一个
jax_legacy_prng_key
标志,以支持在使用传统(原始)PRNG 密钥时发出警告或错误:#17225
JAX 中文文档(十一)(4)https://developer.aliyun.com/article/1559783