JAX 中文文档(一)(2)https://developer.aliyun.com/article/1559830
🔪 随机数
如果所有因糟糕的
rand()
而存疑的科学论文都从图书馆书架上消失,每个书架上会有一个拳头大小的空白。 - Numerical Recipes
RNG 和状态
您习惯于从 numpy 和其他库中使用有状态的伪随机数生成器(PRNG),这些库在幕后巧妙地隐藏了许多细节,为您提供了伪随机性的丰富源泉:
print(np.random.random()) print(np.random.random()) print(np.random.random())
0.9818293835329528 0.06574727326903418 0.3930007618911092
在底层,numpy 使用Mersenne Twister PRNG 来驱动其伪随机函数。该 PRNG 具有(2^{19937}-1)的周期,并且在任何时候可以由624 个 32 位无符号整数和一个表示已使用的“熵”量的位置来描述。
np.random.seed(0) rng_state = np.random.get_state() # print(rng_state) # --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044, # 2481403966, 4042607538, 337614300, ... 614 more numbers..., # 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)
这个伪随机状态向量在每次需要随机数时都会在幕后自动更新,“消耗”Mersenne Twister 状态向量中的 2 个 uint32:
_ = np.random.uniform() rng_state = np.random.get_state() #print(rng_state) # --> ('MT19937', array([2443250962, 1093594115, 1878467924, # ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0) # Let's exhaust the entropy in this PRNG statevector for i in range(311): _ = np.random.uniform() rng_state = np.random.get_state() #print(rng_state) # --> ('MT19937', array([2443250962, 1093594115, 1878467924, # ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0) # Next call iterates the RNG state for a new batch of fake "entropy". _ = np.random.uniform() rng_state = np.random.get_state() # print(rng_state) # --> ('MT19937', array([1499117434, 2949980591, 2242547484, # 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)
魔法 PRNG 状态的问题在于很难推断它在不同线程、进程和设备中的使用和更新方式,并且在熵的生成和消耗细节对最终用户隐藏时,非常容易出错。
Mersenne Twister PRNG 也被认为存在一些问题,它具有较大的 2.5kB 状态大小,导致初始化问题很多。它在现代的 BigCrush 测试中失败,并且通常速度较慢。
JAX PRNG
相反,JAX 实现了一个显式的PRNG,其中熵的生成和消耗通过显式传递和迭代 PRNG 状态来处理。JAX 使用一种现代化的Threefry 基于计数器的 PRNG,它是可分裂的。也就是说,其设计允许我们将 PRNG 状态分叉成新的 PRNG,以用于并行随机生成。
随机状态由一个我们称之为密钥的特殊数组元素描述:
from jax import random key = random.key(0) key
Array((), dtype=key<fry>) overlaying: [0 0]
JAX 的随机函数从 PRNG 状态生成伪随机数,但不会改变状态!
复用相同的状态会导致悲伤和单调,剥夺最终用户生命力的混乱:
print(random.normal(key, shape=(1,))) print(key) # No no no! print(random.normal(key, shape=(1,))) print(key)
[-0.20584226] Array((), dtype=key<fry>) overlaying: [0 0] [-0.20584226] Array((), dtype=key<fry>) overlaying: [0 0]
相反,我们分割PRNG 以在每次需要新的伪随机数时获得可用的子密钥:
print("old key", key) key, subkey = random.split(key) normal_pseudorandom = random.normal(subkey, shape=(1,)) print(" \---SPLIT --> new key ", key) print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key Array((), dtype=key<fry>) overlaying: [0 0] \---SPLIT --> new key Array((), dtype=key<fry>) overlaying: [4146024105 967050713] \--> new subkey Array((), dtype=key<fry>) overlaying: [2718843009 1272950319] --> normal [-1.2515389]
我们传播密钥并在需要新的随机数时生成新的子密钥:
print("old key", key) key, subkey = random.split(key) normal_pseudorandom = random.normal(subkey, shape=(1,)) print(" \---SPLIT --> new key ", key) print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key Array((), dtype=key<fry>) overlaying: [4146024105 967050713] \---SPLIT --> new key Array((), dtype=key<fry>) overlaying: [2384771982 3928867769] \--> new subkey Array((), dtype=key<fry>) overlaying: [1278412471 2182328957] --> normal [-0.58665055]
我们可以同时生成多个子密钥:
key, *subkeys = random.split(key, 4) for subkey in subkeys: print(random.normal(subkey, shape=(1,)))
[-0.37533438] [0.98645043] [0.14553197]
🔪 控制流
✔ python 控制流 + 自动微分 ✔
如果您只想将grad
应用于您的 Python 函数,可以使用常规的 Python 控制流结构,没有问题,就像使用Autograd(或 Pytorch 或 TF Eager)一样。
def f(x): if x < 3: return 3. * x ** 2 else: return -4 * x print(grad(f)(2.)) # ok! print(grad(f)(4.)) # ok!
12.0 -4.0
python 控制流 + JIT
使用jit
进行控制流更为复杂,默认情况下具有更多约束。
这个可以工作:
@jit def f(x): for i in range(3): x = 2 * x return x print(f(3))
24
这样也可以:
@jit def g(x): y = 0. for i in range(x.shape[0]): y = y + x[i] return y print(g(jnp.array([1., 2., 3.])))
6.0
但默认情况下,这样不行:
@jit def f(x): if x < 3: return 3. * x ** 2 else: return -4 * x # This will fail! f(2)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].. The error occurred while tracing the function f at /tmp/ipykernel_1227/3402096563.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x. See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
怎么回事!?
当我们jit
编译一个函数时,通常希望编译一个适用于许多不同参数值的函数版本,以便我们可以缓存和重复使用编译代码。这样我们就不必在每次函数评估时重新编译。
例如,如果我们在数组jnp.array([1., 2., 3.], jnp.float32)
上评估@jit
函数,我们可能希望编译代码,以便我们可以重复使用它来在jnp.array([4., 5., 6.], jnp.float32)
上评估函数,从而节省编译时间。
要查看适用于许多不同参数值的 Python 代码视图,JAX 会跟踪抽象值,这些抽象值表示可能输入集合的集合。有关不同的转换使用不同的抽象级别,详见多个不同的抽象级别。
默认情况下,jit
会在ShapedArray
抽象级别上跟踪您的代码,其中每个抽象值表示具有固定形状和 dtype 的所有数组值的集合。例如,如果我们使用抽象值ShapedArray((3,), jnp.float32)
进行跟踪,我们会得到可以重复使用于相应数组集合中的任何具体值的函数视图。这意味着我们可以节省编译时间。
但这里有一个权衡:如果我们在ShapedArray((), jnp.float32)
上跟踪 Python 函数,它不专注于具体值,当我们遇到像if x < 3
这样的行时,表达式x < 3
会评估为表示集合{True, False}
的抽象ShapedArray((), jnp.bool_)
。当 Python 尝试将其强制转换为具体的True
或False
时,我们会收到错误:我们不知道应该选择哪个分支,无法继续跟踪!权衡是,使用更高级别的抽象,我们获得 Python 代码的更一般视图(因此节省重新编译的时间),但我们需要更多约束来完成跟踪。
好消息是,您可以自行控制这种权衡。通过启用jit
对更精细的抽象值进行跟踪,您可以放宽跟踪约束。例如,使用jit
的static_argnums
参数,我们可以指定在某些参数的具体值上进行跟踪。下面是这个例子函数:
def f(x): if x < 3: return 3. * x ** 2 else: return -4 * x f = jit(f, static_argnums=(0,)) print(f(2.))
12.0
下面是另一个例子,这次涉及循环:
def f(x, n): y = 0. for i in range(n): y = y + x[i] return y f = jit(f, static_argnums=(1,)) f(jnp.array([2., 3., 4.]), 2)
Array(5., dtype=float32)
实际上,循环被静态展开。JAX 也可以在更高的抽象级别进行追踪,比如 Unshaped
,但目前对于任何变换来说这都不是默认的。
️⚠️ 具有参数-值相关形状的函数
这些控制流问题也以更微妙的方式出现:我们希望 jit 的数值函数不能根据参数 值 来特化内部数组的形状(在参数 形状 上特化是可以的)。举个简单的例子,让我们创建一个函数,其输出恰好依赖于输入变量 length
。
def example_fun(length, val): return jnp.ones((length,)) * val # un-jit'd works fine print(example_fun(5, 4))
[4\. 4\. 4\. 4\. 4.]
bad_example_jit = jit(example_fun) # this will fail: bad_example_jit(10, 4)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>,). If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions. The error occurred while tracing the function example_fun at /tmp/ipykernel_1227/1210496444.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument length.
# static_argnums tells JAX to recompile on changes at these argument positions: good_example_jit = jit(example_fun, static_argnums=(0,)) # first compile print(good_example_jit(10, 4)) # recompiles print(good_example_jit(5, 4))
[4\. 4\. 4\. 4\. 4\. 4\. 4\. 4\. 4\. 4.] [4\. 4\. 4\. 4\. 4.]
如果在我们的示例中 length
很少更改,那么 static_argnums
就会很方便,但如果它经常更改,那将是灾难性的!
最后,如果您的函数具有全局副作用,JAX 的追踪器可能会导致一些奇怪的事情发生。一个常见的坑是尝试在 jit 函数中打印数组:
@jit def f(x): print(x) y = 2 * x print(y) return y f(2)
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Array(4, dtype=int32, weak_type=True)
结构化控制流原语
JAX 中有更多控制流选项。假设您想避免重新编译但仍想使用可追踪的控制流,并避免展开大循环。那么您可以使用这四个结构化的控制流原语:
lax.cond
可微分lax.while_loop
前向模式可微分lax.fori_loop
前向模式可微分;如果端点是静态的,则前向和反向模式均可微分。lax.scan
可微分
cond
python 等效:
def cond(pred, true_fun, false_fun, operand): if pred: return true_fun(operand) else: return false_fun(operand)
from jax import lax operand = jnp.array([0.]) lax.cond(True, lambda x: x+1, lambda x: x-1, operand) # --> array([1.], dtype=float32) lax.cond(False, lambda x: x+1, lambda x: x-1, operand) # --> array([-1.], dtype=float32)
Array([-1.], dtype=float32)
jax.lax
还提供了另外两个函数,允许根据动态谓词进行分支:
lax.select
类似于lax.cond
的批处理版本,选择项表达为预先计算的数组而不是函数。lax.switch
类似于lax.cond
,但允许在任意数量的可调用选项之间进行切换。
另外,jax.numpy
提供了几个 numpy 风格的接口:
jnp.where
的三个参数是lax.select
的 numpy 风格封装。jnp.piecewise
是lax.switch
的 numpy 风格封装,但是根据一系列布尔条件而不是单个标量索引进行切换。jnp.select
的 API 类似于jnp.piecewise
,但选择项是作为预先计算的数组而不是函数给出的。它是基于多次调用lax.select
实现的。
while_loop
python 等效:
def while_loop(cond_fun, body_fun, init_val): val = init_val while cond_fun(val): val = body_fun(val) return val
init_val = 0 cond_fun = lambda x: x<10 body_fun = lambda x: x+1 lax.while_loop(cond_fun, body_fun, init_val) # --> array(10, dtype=int32)
Array(10, dtype=int32, weak_type=True)
fori_loop
python 等效:
def fori_loop(start, stop, body_fun, init_val): val = init_val for i in range(start, stop): val = body_fun(i, val) return val
init_val = 0 start = 0 stop = 10 body_fun = lambda i,x: x+i lax.fori_loop(start, stop, body_fun, init_val) # --> array(45, dtype=int32)
Array(45, dtype=int32, weak_type=True)
总结
[\begin{split} \begin{array} {r|rr} \hline \ \textrm{构造} & \textrm{jit} & \textrm{grad} \ \hline \ \textrm{if} & ❌ & ✔ \ \textrm{for} & ✔* & ✔\ \textrm{while} & ✔* & ✔\ \textrm{lax.cond} & ✔ & ✔\ \textrm{lax.while_loop} & ✔ & \textrm{前向}\ \textrm{lax.fori_loop} & ✔ & \textrm{前向}\ \textrm{lax.scan} & ✔ & ✔\ \hline \end{array} \end{split}
\begin{split} \begin{array} {r|rr} \hline \ \textrm{构造} & \textrm{jit} & \textrm{grad} \ \hline \ \textrm{if} & ❌ & ✔ \ \textrm{for} & ✔* & ✔\ \textrm{while} & ✔* & ✔\ \textrm{lax.cond} & ✔ & ✔\ \textrm{lax.while_loop} & ✔ & \textrm{前向}\ \textrm{lax.fori_loop} & ✔ & \textrm{前向}\ \textrm{lax.scan} & ✔ & ✔\ \hline \end{array} \end{split}\begin{split} \begin{array} {r|rr} \hline \ \textrm{构造} & \textrm{jit} & \textrm{grad} \ \hline \ \textrm{if} & ❌ & ✔ \ \textrm{for} & ✔* & ✔\ \textrm{while} & ✔* & ✔\ \textrm{lax.cond} & ✔ & ✔\ \textrm{lax.while_loop} & ✔ & \textrm{前向}\ \textrm{lax.fori_loop} & ✔ & \textrm{前向}\ \textrm{lax.scan} & ✔ & ✔\ \hline \end{array} \end{split}
\begin{split} \begin{array} {r|rr} \hline \ \textrm{构造} & \textrm{jit} & \textrm{grad} \ \hline \ \textrm{if} & ❌ & ✔ \ \textrm{for} & ✔* & ✔\ \textrm{while} & ✔* & ✔\ \textrm{lax.cond} & ✔ & ✔\ \textrm{lax.while_loop} & ✔ & \textrm{前向}\ \textrm{lax.fori_loop} & ✔ & \textrm{前向}\ \textrm{lax.scan} & ✔ & ✔\ \hline \end{array} \end{split}](\ast) = 参数-值-独立循环条件 - 展开循环
JAX 中文文档(一)(4)https://developer.aliyun.com/article/1559834