JAX 中文文档(十四)(3)

简介: JAX 中文文档(十四)

JAX 中文文档(十四)(2)https://developer.aliyun.com/article/1559756


jax.random 模块

原文:jax.readthedocs.io/en/latest/jax.random.html

伪随机数生成的实用程序。

jax.random 包提供了多种例程,用于确定性生成伪随机数序列。

基本用法

>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
...   key, subkey = jax.random.split(key)
...   params = compiled_update(subkey, params, next(batches)) 

PRNG keys

与 NumPy 和 SciPy 用户习惯的 有状态 伪随机数生成器(PRNGs)不同,JAX 随机函数都要求作为第一个参数传递一个显式的 PRNG 状态。随机状态由我们称之为 key 的特殊数组元素类型描述,通常由 jax.random.key() 函数生成:

>>> from jax import random
>>> key = random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0] 

然后,可以在 JAX 的任何随机数生成例程中使用该 key:

>>> random.uniform(key)
Array(0.41845703, dtype=float32) 

请注意,使用 key 不会修改它,因此重复使用相同的 key 将导致相同的结果:

>>> random.uniform(key)
Array(0.41845703, dtype=float32) 

如果需要新的随机数,可以使用 jax.random.split() 生成新的子 key:

>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.10536897, dtype=float32) 

注意

类型化的 key 数组,例如上述 key,在 JAX v0.4.16 中引入。在此之前,key 通常以 uint32 数组表示,其最终维度表示 key 的位级表示。

两种形式的 key 数组仍然可以通过 jax.random 模块创建和使用。新式的类型化 key 数组使用 jax.random.key() 创建。传统的 uint32 key 数组使用 jax.random.PRNGKey() 创建。

要在两者之间进行转换,使用 jax.random.key_data()jax.random.wrap_key_data()。当与 JAX 外部系统(例如将数组导出为可序列化格式)交互或将 key 传递给基于 JAX 的库时,可能需要传统的 key 格式。

否则,建议使用类型化的 key。传统 key 相对于类型化 key 的注意事项包括:

  • 它们有一个额外的尾维度。
  • 它们具有数字数据类型 (uint32),允许进行通常不用于 key 的操作,例如整数算术。
  • 它们不包含有关 RNG 实现的信息。当传统 key 传递给 jax.random 函数时,全局配置设置确定 RNG 实现(参见下文的“高级 RNG 配置”)。

要了解更多关于此升级以及 key 类型设计的信息,请参阅 JEP 9263

高级

设计和背景

TLDR:JAX PRNG = Threefry counter PRNG + 一个功能数组导向的 分裂模型

更多详细信息,请参阅 docs/jep/263-prng.md

总结一下,JAX PRNG 还包括但不限于以下要求:

  1. 确保可重现性,
  2. 良好的并行化,无论是向量化(生成数组值)还是多副本、多核计算。特别是它不应在随机函数调用之间使用顺序约束。

高级 RNG 配置

JAX 提供了几种 PRNG 实现。可以通过可选的 impl 关键字参数选择特定的实现。如果在密钥构造函数中没有传递 impl 选项,则实现由全局 jax_default_prng_impl 配置标志确定。

  • “rbg” 使用 ThreeFry 进行分割,并使用 XLA RBG 进行数据生成。
  • “unsafe_rbg” 仅用于演示目的,使用 RBG 进行分割(使用未经测试的虚构算法)和生成。
  • 这些实验性实现生成的随机流尚未经过任何经验随机性测试(例如 Big Crush)。生成的随机比特可能会在 JAX 的不同版本之间变化。

不使用默认 RNG 的可能原因是:

  1. 可能编译速度较慢(特别是对于 Google Cloud TPU)
  2. 在 TPU 上执行速度较慢
  3. 不支持高效的自动分片/分区

这里是一个简短的总结:

属性 Threefry Threefry* rbg unsafe_rbg rbg** unsafe_rbg**
在 TPU 上最快
可以高效分片(使用 pjit)
在分片中相同
在 CPU/GPU/TPU 上相同
在 JAX/XLA 版本间相同

(*): 设置了jax_threefry_partitionable=1

(**): 设置了XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1

“rbg” 和 “unsafe_rbg” 之间的区别在于,“rbg” 用于生成随机值时使用了较不稳定/研究较少的哈希函数(但不用于  jax.random.split 或 jax.random.fold_in),而 “unsafe_rbg” 还额外在  jax.random.split 和 jax.random.fold_in  中使用了更不稳定的哈希函数。因此,在不同密钥生成的随机流质量方面不那么安全。

要了解有关 jax_threefry_partitionable 的更多信息,请参阅jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers

API 参考

密钥创建与操作

PRNGKey(seed, *[, impl]) 给定整数种子创建伪随机数生成器(PRNG)密钥。
key(seed, *[, impl]) 给定整数种子创建伪随机数生成器(PRNG)密钥。
key_data(密钥) 恢复 PRNG 密钥数组下的密钥数据位。
wrap_key_data(key_bits_array, *[, impl]) 将密钥数据位数组包装成 PRNG 密钥数组。
fold_in(key, data) 将数据折叠到 PRNG 密钥中,形成新的 PRNG 密钥。
split(key[, num]) 将 PRNG 密钥按添加一个前导轴拆分为 num 个新密钥。
clone(key) 克隆一个密钥以便重复使用。

随机抽样器

ball(key, d[, p, shape, dtype]) 从单位 Lp 球中均匀采样。
bernoulli(key[, p, shape]) 采样给定形状和均值的伯努利分布随机值。
beta(key, a, b[, shape, dtype]) 采样给定形状和浮点数数据类型的贝塔分布随机值。
binomial(key, n, p[, shape, dtype]) 采样给定形状和浮点数数据类型的二项分布随机值。
bits(key[, shape, dtype]) 以无符号整数的形式采样均匀比特。
categorical(key, logits[, axis, shape]) 从分类分布中采样随机值。
cauchy(key[, shape, dtype]) 采样给定形状和浮点数数据类型的柯西分布随机值。
chisquare(key, df[, shape, dtype]) 采样给定形状和浮点数数据类型的卡方分布随机值。
choice(key, a[, shape, replace, p, axis]) 从给定数组中生成随机样本。
dirichlet(key, alpha[, shape, dtype]) 采样给定形状和浮点数数据类型的狄利克雷分布随机值。
double_sided_maxwell(key, loc, scale[, …]) 从双边 Maxwell 分布中采样。
exponential(key[, shape, dtype]) 采样给定形状和浮点数数据类型的指数分布随机值。
f(key, dfnum, dfden[, shape, dtype]) 采样给定形状和浮点数数据类型的 F 分布随机值。
gamma(key, a[, shape, dtype]) 采样给定形状和浮点数数据类型的伽马分布随机值。
generalized_normal(key, p[, shape, dtype]) 从广义正态分布中采样。
geometric(key, p[, shape, dtype]) 采样给定形状和浮点数数据类型的几何分布随机值。
gumbel(key[, shape, dtype]) 采样给定形状和浮点数数据类型的 Gumbel 分布随机值。
laplace(key[, shape, dtype]) 采样给定形状和浮点数数据类型的拉普拉斯分布随机值。
loggamma(key, a[, shape, dtype]) 采样给定形状和浮点数数据类型的对数伽马分布随机值。
logistic(key[, shape, dtype]) 采样给定形状和浮点数数据类型的 logistic 随机值。
lognormal(key[, sigma, shape, dtype]) 采样给定形状和浮点数数据类型的对数正态分布随机值。
maxwell(key[, shape, dtype]) 从单边 Maxwell 分布中采样。
multivariate_normal(key, mean, cov[, shape, …]) 采样给定均值和协方差的多变量正态分布随机值。
normal(key[, shape, dtype]) 采样给定形状和浮点数数据类型的标准正态分布随机值。
orthogonal(key, n[, shape, dtype]) 从正交群 O(n) 中均匀采样。
pareto(key, b[, shape, dtype]) 采样给定形状和浮点数数据类型的帕累托分布随机值。
permutation(key, x[, axis, independent]) 返回随机排列的数组或范围。
poisson(key, lam[, shape, dtype]) 采样给定形状和整数数据类型的泊松分布随机值。
rademacher(key[, shape, dtype]) 从 Rademacher 分布中采样。
randint(key, shape, minval, maxval[, dtype]) 用给定的形状和数据类型在[minval, maxval)范围内示例均匀随机整数值。
[rayleigh(key, scale[, shape, dtype]) 用给定的形状和浮点数数据类型示例瑞利随机值。
t(key, df[, shape, dtype]) 用给定的形状和浮点数数据类型示例学生 t 分布随机值。
triangular(key, left, mode, right[, shape, …]) 用给定的形状和浮点数数据类型示例三角形随机值。
truncated_normal(key, lower, upper[, shape, …]) 用给定的形状和数据类型示例截断标准正态随机值。
uniform(key[, shape, dtype, minval, maxval]) 用给定的形状和数据类型在[minval, maxval)范围内示例均匀随机值。
[wald(key, mean[, shape, dtype]) 用给定的形状和浮点数数据类型示例瓦尔德随机值。
weibull_min(key, scale, concentration[, …]) 从威布尔分布中采样。


JAX 中文文档(十四)(4)https://developer.aliyun.com/article/1559759

相关文章
|
2天前
|
存储 API 索引
JAX 中文文档(十五)(5)
JAX 中文文档(十五)
14 3
|
2天前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(4)
JAX 中文文档(十五)
13 3
|
2天前
|
机器学习/深度学习 数据可视化 编译器
JAX 中文文档(十四)(5)
JAX 中文文档(十四)
9 2
|
2天前
|
算法 API 开发工具
JAX 中文文档(十二)(5)
JAX 中文文档(十二)
7 1
|
2天前
|
API 异构计算 Python
JAX 中文文档(十一)(4)
JAX 中文文档(十一)
8 1
|
2天前
JAX 中文文档(十一)(5)
JAX 中文文档(十一)
6 1
|
2天前
|
并行计算 API 异构计算
JAX 中文文档(十六)(2)
JAX 中文文档(十六)
10 1
|
2天前
|
关系型数据库
JAX 中文文档(十四)(1)
JAX 中文文档(十四)
9 0
|
2天前
|
Python
JAX 中文文档(十四)(4)
JAX 中文文档(十四)
7 0
|
2天前
|
API 异构计算 索引
JAX 中文文档(十四)(2)
JAX 中文文档(十四)
8 0