JAX 中文文档(一)(1)https://developer.aliyun.com/article/1559829
🔪 JAX - 锋利的部分 🔪
原文:
jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
levskaya@ mattjj@
在意大利乡间漫步时,人们会毫不犹豫地告诉您,JAX 具有 “una anima di pura programmazione funzionale”。
JAX 是一种用于表达和组合数值程序转换的语言。JAX 还能够为 CPU 或加速器(GPU/TPU)编译数值程序。对于许多数值和科学程序,JAX 表现出色,但前提是它们必须按照我们下面描述的某些约束条件编写。
import numpy as np from jax import grad, jit from jax import lax from jax import random import jax import jax.numpy as jnp
🔪 纯函数
JAX 的转换和编译设计仅适用于函数式纯的 Python 函数:所有输入数据通过函数参数传递,所有结果通过函数结果输出。纯函数如果以相同的输入调用,将始终返回相同的结果。
下面是一些函数示例,这些函数不是函数式纯的,因此 JAX 的行为与 Python 解释器不同。请注意,这些行为并不由 JAX 系统保证;正确使用 JAX 的方法是仅在函数式纯 Python 函数上使用它。
def impure_print_side_effect(x): print("Executing function") # This is a side-effect return x # The side-effects appear during the first run print ("First call: ", jit(impure_print_side_effect)(4.)) # Subsequent runs with parameters of same type and shape may not show the side-effect # This is because JAX now invokes a cached compilation of the function print ("Second call: ", jit(impure_print_side_effect)(5.)) # JAX re-runs the Python function when the type or shape of the argument changes print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))
Executing function First call: 4.0 Second call: 5.0 Executing function Third call, different type: [5.]
g = 0. def impure_uses_globals(x): return x + g # JAX captures the value of the global during the first run print ("First call: ", jit(impure_uses_globals)(4.)) g = 10. # Update the global # Subsequent runs may silently use the cached value of the globals print ("Second call: ", jit(impure_uses_globals)(5.)) # JAX re-runs the Python function when the type or shape of the argument changes # This will end up reading the latest value of the global print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))
First call: 4.0 Second call: 5.0 Third call, different type: [14.]
g = 0. def impure_saves_global(x): global g g = x return x # JAX runs once the transformed function with special Traced values for arguments print ("First call: ", jit(impure_saves_global)(4.)) print ("Saved global: ", g) # Saved global has an internal JAX value
First call: 4.0 Saved global: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
即使一个 Python 函数在内部实际上使用了有状态的对象,只要它不读取或写入外部状态,它就可以是函数式纯的:
def pure_uses_internal_state(x): state = dict(even=0, odd=0) for i in range(10): state['even' if i % 2 == 0 else 'odd'] += x return state['even'] + state['odd'] print(jit(pure_uses_internal_state)(5.))
50.0
不建议在希望jit
的任何 JAX 函数中使用迭代器或任何控制流原语。原因是迭代器是一个引入状态以检索下一个元素的 Python 对象。因此,它与 JAX 的函数式编程模型不兼容。在下面的代码中,有一些尝试在 JAX 中使用迭代器的错误示例。其中大多数会返回错误,但有些会给出意外的结果。
import jax.numpy as jnp import jax.lax as lax from jax import make_jaxpr # lax.fori_loop array = jnp.arange(10) print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45 iterator = iter(range(10)) print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0 # lax.scan def func11(arr, extra): ones = jnp.ones(arr.shape) def body(carry, aelems): ae1, ae2 = aelems return (carry + ae1 * ae2 + extra, carry) return lax.scan(body, 0., (arr, ones)) make_jaxpr(func11)(jnp.arange(16), 5.) # make_jaxpr(func11)(iter(range(16)), 5.) # throws error # lax.cond array_operand = jnp.array([0.]) lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand) iter_operand = iter(range(10)) # lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error
45 0
🔪 原地更新
在 Numpy 中,您习惯于执行以下操作:
numpy_array = np.zeros((3,3), dtype=np.float32) print("original array:") print(numpy_array) # In place, mutating update numpy_array[1, :] = 1.0 print("updated array:") print(numpy_array)
original array: [[0\. 0\. 0.] [0\. 0\. 0.] [0\. 0\. 0.]] updated array: [[0\. 0\. 0.] [1\. 1\. 1.] [0\. 0\. 0.]]
然而,如果我们尝试在 JAX 设备数组上就地更新,我们会收到错误!(☉_☉)
%xmode Minimal
Exception reporting mode: Minimal
jax_array = jnp.zeros((3,3), dtype=jnp.float32) # In place update of JAX's array will yield an error! jax_array[1, :] = 1.0
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
允许变量在原地变异会使程序分析和转换变得困难。JAX 要求程序是纯函数。
相反,JAX 提供了对 JAX 数组上的 .at
属性进行函数式数组更新。
️⚠️ 在 jit
的代码中和 lax.while_loop
或 lax.fori_loop
中,切片的大小不能是参数 值 的函数,而只能是参数 形状 的函数 — 切片的起始索引没有此类限制。有关此限制的更多信息,请参阅下面的 控制流 部分。
数组更新:x.at[idx].set(y)
例如,上述更新可以写成:
updated_array = jax_array.at[1, :].set(1.0) print("updated array:\n", updated_array)
updated array: [[0\. 0\. 0.] [1\. 1\. 1.] [0\. 0\. 0.]]
JAX 的数组更新函数与其 NumPy 版本不同,是在原地外执行的。也就是说,更新后的数组作为新数组返回,原始数组不会被更新修改。
print("original array unchanged:\n", jax_array)
original array unchanged: [[0\. 0\. 0.] [0\. 0\. 0.] [0\. 0\. 0.]]
然而,在jit编译的代码内部,如果x.at[idx].set(y)
的输入值 x
没有被重用,编译器会优化数组更新以进行原地操作。
使用其他操作的数组更新
索引数组更新不仅限于覆盖值。例如,我们可以进行索引加法如下:
print("original array:") jax_array = jnp.ones((5, 6)) print(jax_array) new_jax_array = jax_array.at[::2, 3:].add(7.) print("new array post-addition:") print(new_jax_array)
original array: [[1\. 1\. 1\. 1\. 1\. 1.] [1\. 1\. 1\. 1\. 1\. 1.] [1\. 1\. 1\. 1\. 1\. 1.] [1\. 1\. 1\. 1\. 1\. 1.] [1\. 1\. 1\. 1\. 1\. 1.]] new array post-addition: [[1\. 1\. 1\. 8\. 8\. 8.] [1\. 1\. 1\. 1\. 1\. 1.] [1\. 1\. 1\. 8\. 8\. 8.] [1\. 1\. 1\. 1\. 1\. 1.] [1\. 1\. 1\. 8\. 8\. 8.]]
有关索引数组更新的更多详细信息,请参阅.at
属性的文档。
🔪 超出边界索引
在 NumPy 中,当您索引数组超出其边界时,通常会抛出错误,例如:
np.arange(10)[11]
IndexError: index 11 is out of bounds for axis 0 with size 10
然而,在加速器上运行的代码中引发错误可能会很困难或不可能。因此,JAX 必须为超出边界的索引选择一些非错误行为(类似于无效的浮点算术结果为NaN
的情况)。当索引操作是数组索引更新时(例如index_add
或类似的原语),将跳过超出边界的索引;当操作是数组索引检索时(例如 NumPy 索引或类似的原语),索引将夹紧到数组的边界,因为必须返回某些内容。例如,数组的最后一个值将从此索引操作中返回:
jnp.arange(10)[11]
Array(9, dtype=int32)
如果您希望对超出边界索引的行为有更精细的控制,可以使用ndarray.at
的可选参数;例如:
jnp.arange(10.0).at[11].get()
Array(9., dtype=float32)
jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan)
Array(nan, dtype=float32)
注意由于这种索引检索行为,像jnp.nanargmin
和jnp.nanargmax
这样的函数在由 NaN 组成的切片中返回-1,而 NumPy 会抛出错误。
还请注意,由于上述两种行为不是互为反操作,反向模式自动微分(将索引更新转换为索引检索及其反之)将不会保留超出边界索引的语义。因此,将 JAX 中的超出边界索引视为未定义行为可能是个好主意。
🔪 非数组输入:NumPy vs. JAX
NumPy 通常可以接受 Python 列表或元组作为其 API 函数的输入:
np.sum([1, 2, 3])
np.int64(6)
JAX 在这方面有所不同,通常会返回有用的错误:
jnp.sum([1, 2, 3])
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.
这是一个有意的设计选择,因为向追踪函数传递列表或元组可能导致性能下降,而这种性能下降可能很难检测到。
例如,请考虑允许列表输入的jnp.sum
的以下宽松版本:
def permissive_sum(x): return jnp.sum(jnp.array(x)) x = list(range(10)) permissive_sum(x)
Array(45, dtype=int32)
输出与预期相符,但这隐藏了底层的潜在性能问题。在 JAX 的追踪和 JIT 编译模型中,Python 列表或元组中的每个元素都被视为单独的 JAX 变量,并分别处理和推送到设备。这可以在上面的permissive_sum
函数的 jaxpr 中看到:
make_jaxpr(permissive_sum)(x)
{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[] g:i32[] h:i32[] i:i32[] j:i32[]. let k:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a l:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b m:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c n:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d o:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e p:i32[] = convert_element_type[new_dtype=int32 weak_type=False] f q:i32[] = convert_element_type[new_dtype=int32 weak_type=False] g r:i32[] = convert_element_type[new_dtype=int32 weak_type=False] h s:i32[] = convert_element_type[new_dtype=int32 weak_type=False] i t:i32[] = convert_element_type[new_dtype=int32 weak_type=False] j u:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] k v:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] l w:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] m x:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] n y:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] o z:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] p ba:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] q bb:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] r bc:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] s bd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] t be:i32[10] = concatenate[dimension=0] u v w x y z ba bb bc bd bf:i32[] = reduce_sum[axes=(0,)] be in (bf,) }
列表的每个条目都作为单独的输入处理,导致追踪和编译开销随列表大小线性增长。为了避免这样的意外,JAX 避免将列表和元组隐式转换为数组。
如果您希望将元组或列表传递给 JAX 函数,可以首先显式地将其转换为数组:
jnp.sum(jnp.array(x))
Array(45, dtype=int32)
JAX 中文文档(一)(3)https://developer.aliyun.com/article/1559832