JAX 增强提案(JEPs)
大多数改动可以通过简单的问题/讨论和拉取请求进行讨论。
然而,有些变更范围较大或需要更多讨论,这些应该作为 JEP 实现。这允许编写更长的文档,可以在拉取请求中进行讨论。
JEP 的结构被保持尽可能轻量以便开始,并可能在以后进行扩展。
当你需要使用一个 JEP 时
- 当你的改动需要一个设计文档时。我们更倾向于将设计收集为 JEPs,以便更好地发现和进一步参考。
- 当你的改动需要广泛讨论时。对于问题或拉取请求进行相对较短的讨论是可以接受的,但是当讨论变得更长时,这对于后续消化就变得不实际了。JEP 允许通过更新主文档添加讨论摘要,并且这些更新本身可以在添加 JEP 的拉取请求中进行讨论。
如何开始一个 JEP
首先,创建一个带有JEP 标签的问题。所有与 JEP 相关的拉取请求(即添加 JEP 本身以及任何实现拉取请求)都应链接到此问题。
然后创建一个添加名为%d-{short-title}.md 文件的拉取请求 - 其中数字是问题号。
- 263: JAX PRNG 设计
- 2026: JAX 可转换函数的自定义 JVP/VJP 规则
- 4008: 自定义 VJP 和
nondiff_argnums
更新 - 4410: Omnistaging
- 9263: 类型键和可插拔 RNG 的设计
- 9407: JAX 类型提升语义设计
- 9419: Jax 和 Jaxlib 版本控制
- 10657: JAX 中的顺序副作用
- 11830:
jax.remat
/jax.checkpoint
的新实现 - 12049: JAX 的类型注解路线图
- 14273:
shard_map
(shmap
)用于简单的每设备代码 - 15856:
jax.extend
,一个扩展模块 - 17111:
shard_map
(和其他映射)的高效转置 - 18137: JAX NumPy 和 SciPy 包装器的范围
一些早期的 JEP 实际上是从其他文档、问题和拉取请求后来转换而来的,因此它们可能不完全符合上述过程。
JAX PRNG 设计
我们希望一个 PRNG 设计
- 是表达力强的,因为它方便使用,并且不会限制用户编写具有精确所需行为的数值程序的能力,
- 以一种与后端无关的方式启用可复现的程序执行,
- 具有对
@jit
编译边界和设备后端不变的语义, - 使用 SIMD 硬件启用向量化以生成数组值,
- 是可并行化的,因为它不会在随机函数调用之间添加顺序约束,否则这些调用没有数据依赖,
- 能够扩展到多副本、多核和分布式计算,
- 与 JAX 和 XLA 的语义和设计哲学契合(这些哲学最终是由其他实际问题驱动的)。
作为这些的必然结果,我们认为设计应该是功能性的。另一个推论是,至少在当前硬件约束条件下,我们将在软件中进行 PRNG。
TLDR JAX PRNG = Threefry counter PRNG + 一个功能性数组导向的分裂模型
内容
- 三种编程模型和玩具示例程序
- 设计
- 更现实的用户示例程序
- 折衷和替代方案
三种编程模型和玩具示例程序
这里是一个类似于在 Numpy 程序中经常使用的有状态全局PRNG 的玩具示例:
def foo(): return bar() + baz() def bar(): return rand(RNG, (3, 4)) def baz(): return rand(RNG, (3, 4)) def main(): global RNG RNG = RandomState(0) return foo()
要在这里实现可复现性,我们需要控制 bar()和 baz()的评估顺序,即使它们之间没有显式的数据依赖关系。这种由可复现性(#2)引起的顺序要求违反了可并行性(#5),并且与 JAX 或 XLA 的功能语义(#6)不符合,在其中子表达式可以以任何顺序评估。即使我们不需要可复现性,因此允许任何评估顺序,由于需要更新共享状态,跨调用的并行化(#5)仍将变得困难。此外,由于需要在 Python 和任何编译代码中访问和维护相同的 PRNG 状态,这种模型可能会导致工程挑战,以实现编译不变性(#3)和扩展到多个副本(#6)。最后,表达力受到限制(#1),因为没有办法让 foo()调用 bar()或 baz()而不影响其自身的(隐式)PRNG 状态。
是否模型支持向量化(#4)取决于一些额外的细节。在 Numpy 中,PRNG 向量化受到顺序等效保证的限制:
In [1]: rng = np.random.RandomState(0) In [2]: rng.randn(2) Out[2]: array([1.76405235, 0.40015721]) In [3]: rng = np.random.RandomState(0) In [4]: np.stack([rng.randn() for _ in range(2)]) Out[4]: array([1.76405235, 0.40015721])
允许在生成数组的原始 PRNG 函数调用中进行向量化(#4)(例如,使用形状参数调用 rand()),我们放弃了这种顺序等效保证。这种向量化可以由本节讨论的任何三种编程模型支持,尽管它激励我们按照下一节中描述的基于计数器的 PRNG 实现来实现。
有状态 PRNG 用户编程模型前景不佳。以下是一个功能模型的示例,但缺少我们称之为分割的关键要素:
def foo(rng_1): y, rng_2 = baz(rng_1) z, rng_3 = bar(rng_2) return y + z, rng_3 def bar(x, rng): val, new_rng = rand(rng, (3, 4)) return val, new_rng def baz(x, rng): val, new_rng = rand(rng, (3, 4)) return val, new_rng def main(): foo(RandomState(0))
这个模型明确地通过所有生成随机值的函数(原始或非原始)线程化 PRNG 状态:也就是说,每个随机函数都必须接受并返回状态。现在,在 foo() 中,调用 baz() 和调用 bar() 之间存在显式的数据依赖关系,因此数据流(以及顺序)是显式的,并且与 JAX 的现有语义相符(#7),与先前的模型不同。这种显式线程化还可以使语义不变到编译边界(#3)。
对程序员来说,显式线程化是不方便的。但更糟糕的是,它实际上并没有改进表达能力(#1):foo() 仍然没有办法在调用 bar() 或 baz() 的同时保持自己的 PRNG 状态。没有了解其调用者或它们调用的子例程,函数必须在每个地方防御性地传入和返回 rng 状态。此外,它也没有改进并行化的前景(#5)或扩展到多个副本的能力(#6),因为一切仍然是顺序的,即使在功能编程意义上顺序被显式地表示出来。
简而言之,通过显式地线程化状态使代码功能化并不能实现我们的表达性目标(#1)和性能目标(#5,#6)。
在前面的两种模型中的关键问题是存在过多的顺序依赖。为了减少顺序依赖性,我们使用功能性splittable PRNGs。分割是一种机制,用于在保持通常理想的 PRNG 属性的同时‘分叉’新的 PRNG 状态为两个 PRNG 状态(两个新流可以在计算上并行化并产生独立的随机值,即它们的行为类似于multistreams)。
def foo(rng_1): rng_2, rng_3 = split(rng_1, 2) return bar(rng_2) + baz(rng_3) def bar(x, rng): return rand(rng, (3, 4)) def baz(x, rng): return rand(rng, (3, 4)) def main(): foo(RandomState(0))
一些需要注意的点:
- 调用 bar() 和 baz() 的顺序无关紧要,它们可以以任何顺序评估,而不会影响结果的值,这解决了剩下的性能目标(#5,#6),
- 函数不需要返回更新版本的 PRNG,并且可以直接调用随机子例程而不影响现有的 PRNG 状态,从而改善了来自其他功能模型的表达能力(#1)。
例如并未显示,但由于选择(2),推进 PRNG 状态的唯一方法是调用 split()。也就是说,我们有两种实现(1)的方式,它们在是否将显式调用 split() 添加到用户程序上有所不同,就像上面的例子一样,或者改为加入显式线程。我们更喜欢前者,即显式分割版本,因为我们可以轻松地基于它实现显式线程版本。
设计
我们可以使用 基于计数器的 PRNG 设计,特别是如 Parallel random numbers: as easy as 1, 2, 3 中描述的 Threefry 哈希函数。我们利用计数器实现高效的向量化:对于给定的密钥,我们可以通过在整数范围 [k + 1, …, k + sample_size] 上映射哈希函数,以向量化的方式生成值数组。我们与哈希函数一起使用密钥实现 可分割 PRNGs:也就是说,分割是从现有密钥生成两个新密钥的一种方式。
type Sample = Int256 type Key = Sample -- important identification for splitting type Count = Int32 hash :: Key -> Count -> Int256 -- output type equal to Key and Sample split :: Key -> (Key, Key) split key = (hash key 0, hash key 1) draw_samples :: Key -> Int -> [Sample] draw_samples key n = map (hash key) [1..n]
令人惊讶的是,抽取样本与分割非常相似!关键在于输出类型的差异(即使类型被识别为相同):在一种情况下,该值用于形成感兴趣的随机样本(例如,将随机比特转换为表示随机正态分布的 Float),而在另一种情况下,该值用作进一步哈希的键。
哈希函数参数的不对称性,即 Key 和 Count 类型,后者可以通过任意数量的计算轻松推进,因为我们只需增加整数值,而前者只能通过哈希来推进。这就是为什么我们在向量化中使用计数参数的原因。
更现实的示例用户程序
当步骤需要 PRNG 时(也许是为了 dropout 或 VAE 训练),在主机上的训练循环可能如下所示:
rng = lax.rng.new_rng() for i in xrange(num_steps): rng, rng_input = lax.rng.split(rng) params = compiled_update(rng_input, params, next(batches))
注意,我们将用户负担了显式分割的随机数生成器,但代码根本不需要返回随机数生成器。
以下是我们如何在 stax 神经网络构建器库中使用此 PRNG 模型来实现 dropout:
def Dropout(rate, mode='train'): def init_fun(input_shape): return input_shape, () def apply_fun(rng, params, inputs): if mode == 'train': keep = lax.random.bernoulli(rng, rate, inputs.shape) return np.where(keep, inputs / rate, 0) else: return inputs return init_fun, apply_fun
这里的 rng 值只是用于哈希的密钥,而不是特殊对象。rng 参数传递给每个 apply_fun,因此需要在串行和并行组合器中进行处理以进行分割:
def serial(*layers): init_funs, apply_funs = zip(*layers) def init_fun(input_shape): ... def apply_fun(rng, params, inputs): rngs = split(rng, len(layers)) for rng, param, apply_fun in zip(rngs, params, apply_funs): inputs = apply_fun(rng, param, inputs) return inputs return init_fun, apply_fun def parallel(*layers): init_funs, apply_funs = zip(*layers) def init_fun(input_shape): ... def apply_fun(rng, params, inputs): rngs = split(rng, len(layers)) return [f(r, p, x) for f, r, p, x in zip(apply_funs, rngs, params, inputs)] return init_fun, apply_fun
在这里,我们使用了一个简单的扩展版本的 split,可以生成多个副本。
折衷和替代方案
- 我们没有利用任何设备硬件 PRNG。
- 我们目前无法控制所有后端的硬件 PRNG 状态。
- 即使我们这样做了,它也会依赖后端,并且我们可能需要在随机调用之间引入顺序依赖关系,以确保确定性排序和因此可重复性。
- 我们不知道任何软件 PRNG 应成为瓶颈的工作负载。
- 我们可以考虑提供额外的 API,允许用户访问硬件 PRNG,这样他们就可以放弃其他的期望(比如严格的可重现性)。
- 我们放弃了顺序等效的保证,即在一次调用中创建随机数组与逐个创建扁平化数组的随机元素产生相同的值。
- 这个属性很可能与向量化不兼容(一个高优先级)。
- 我们不知道有哪些用户或示例认为此属性很重要。
- 用户可以在此 API 之上编写一层以提供此保证。
- 我们不能完全遵循
numpy.random
的 API。
为 JAX-可变换函数定义自定义 JVP/VJP 规则
原文:
jax.readthedocs.io/en/latest/jep/2026-custom-derivatives.html
这是一个设计文档,解释了关于设计和实现jax.custom_jvp
和jax.custom_vjp
背后的一些思路。有关面向用户的文档,请参阅教程笔记本。
在 JAX 中有两种定义微分规则的方法:
- 使用
jax.custom_jvp
和jax.custom_vjp
为已经可以 JAX-变换的 Python 函数定义自定义微分规则;和 - 定义新的
core.Primitive
实例及其所有转换规则,例如调用来自其他系统(如求解器、仿真器或通用数值计算系统)的函数。
本文只涉及 #1。
内容
- 目标
- 非目标
- 主要问题描述
- vmap-removes-custom-jvp 语义问题
- Python 灵活性问题
- 解决方案思路
- 实现注意事项
目标
我们希望用户可以定制其代码的正向和/或反向模式微分行为。这种定制
- 应该具有清晰一致的语义,以及其工作方式与其他 JAX 变换如何组合;和
- 应该灵活地支持像Autograd和PyTorch中的使用案例和工作流,包括涉及 Python 控制流的微分和 NaN 调试工作流。
作为JAX 开发者,我们希望编写库函数,如logit
和expit
,这些函数在其他原语的基础上定义,但在微分的目的上具有类似原语的行为,因此我们希望为它们定义自定义微分规则,这些规则可能更稳定或更高效。特别是,我们不想为logit
和expit
等函数指定vmap
或jit
规则。
作为一个延伸目标,我们希望将 JAX 打造成一个非常适合希望为高阶函数如 fixed_point
、odeint
等添加自定义微分规则的高级用户的环境;这个设计文档不会解决这个问题,但我们希望能够确保我们不会排除解决这个问题的好方法。
也就是说,我们的主要目标是
次要目标是 3. 清理和简化用户体验(符号零、kwargs 等)4. 朝着用户能够轻松添加 fixed_point
、odeint
、root
等的世界迈进。
总体而言,我们希望关闭 #116, #1097, #1249, #1275, #1366, #1723, #1670, #1875, #1938,并替换自 #636, #818 和其他问题中的 custom_transforms 机制。
非目标
下面是我们不打算实现的目标:
custom_transforms
机制旨在提供一个转换通用机制,用于定制行为,原则上(尽管在实践中从未真正使用)允许用户定制任何转换的规则,同时以某种方式继承其他转换的“透明”行为。相反,我们仅打算解决微分的定制化问题(分别为 JVP 和 VJP)。 实际上只有微分是被请求的用例,通过专门用于微分,我们可以减少复杂性并提高灵活性。要控制所有规则,用户可以直接编写一个原始函数。- 我们不打算将数学美学放在用户便利性、实现简单性及清晰性之上。特别是,虽然自定义 VJP 签名
a -> (b, CT b --o CT a)
在数学上是美观的,但如果由于返回类型中的闭包而在 Python 机制中实现困难,我们愿意采取一些更显式处理残差的方法。 - 序列化支持,即以分阶段序列化的程序表示形式加载并进行更多 JAX 转换,而不仅仅是评估,目前不在这些自定义 JVP/VJP 转换规则的范围内。序列化不仅对希望保存计算表示形式(并在加载后转换它)的研究人员有用,还可能考虑将 jaxpr 转换实现在 Python 之外,或者将 jaxprs 作为 MLIR 语言的一部分。通过将其定义为这一设计的非目标,我们在可存放 Python 可调用对象的位置上拥有更少的约束。
JAX 中文文档(十一)(2)https://developer.aliyun.com/article/1559780