JAX 中文文档(十四)(3)

简介: JAX 中文文档(十四)

JAX 中文文档(十四)(2)https://developer.aliyun.com/article/1559756


jax.random 模块

原文:jax.readthedocs.io/en/latest/jax.random.html

伪随机数生成的实用程序。

jax.random 包提供了多种例程,用于确定性生成伪随机数序列。

基本用法

>>> seed = 1701
>>> num_steps = 100
>>> key = jax.random.key(seed)
>>> for i in range(num_steps):
...   key, subkey = jax.random.split(key)
...   params = compiled_update(subkey, params, next(batches)) 

PRNG keys

与 NumPy 和 SciPy 用户习惯的 有状态 伪随机数生成器(PRNGs)不同,JAX 随机函数都要求作为第一个参数传递一个显式的 PRNG 状态。随机状态由我们称之为 key 的特殊数组元素类型描述,通常由 jax.random.key() 函数生成:

>>> from jax import random
>>> key = random.key(0)
>>> key
Array((), dtype=key<fry>) overlaying:
[0 0] 

然后,可以在 JAX 的任何随机数生成例程中使用该 key:

>>> random.uniform(key)
Array(0.41845703, dtype=float32) 

请注意,使用 key 不会修改它,因此重复使用相同的 key 将导致相同的结果:

>>> random.uniform(key)
Array(0.41845703, dtype=float32) 

如果需要新的随机数,可以使用 jax.random.split() 生成新的子 key:

>>> key, subkey = random.split(key)
>>> random.uniform(subkey)
Array(0.10536897, dtype=float32) 

注意

类型化的 key 数组,例如上述 key,在 JAX v0.4.16 中引入。在此之前,key 通常以 uint32 数组表示,其最终维度表示 key 的位级表示。

两种形式的 key 数组仍然可以通过 jax.random 模块创建和使用。新式的类型化 key 数组使用 jax.random.key() 创建。传统的 uint32 key 数组使用 jax.random.PRNGKey() 创建。

要在两者之间进行转换,使用 jax.random.key_data()jax.random.wrap_key_data()。当与 JAX 外部系统(例如将数组导出为可序列化格式)交互或将 key 传递给基于 JAX 的库时,可能需要传统的 key 格式。

否则,建议使用类型化的 key。传统 key 相对于类型化 key 的注意事项包括:

  • 它们有一个额外的尾维度。
  • 它们具有数字数据类型 (uint32),允许进行通常不用于 key 的操作,例如整数算术。
  • 它们不包含有关 RNG 实现的信息。当传统 key 传递给 jax.random 函数时,全局配置设置确定 RNG 实现(参见下文的“高级 RNG 配置”)。

要了解更多关于此升级以及 key 类型设计的信息,请参阅 JEP 9263

高级

设计和背景

TLDR:JAX PRNG = Threefry counter PRNG + 一个功能数组导向的 分裂模型

更多详细信息,请参阅 docs/jep/263-prng.md

总结一下,JAX PRNG 还包括但不限于以下要求:

  1. 确保可重现性,
  2. 良好的并行化,无论是向量化(生成数组值)还是多副本、多核计算。特别是它不应在随机函数调用之间使用顺序约束。

高级 RNG 配置

JAX 提供了几种 PRNG 实现。可以通过可选的 impl 关键字参数选择特定的实现。如果在密钥构造函数中没有传递 impl 选项,则实现由全局 jax_default_prng_impl 配置标志确定。

  • “rbg” 使用 ThreeFry 进行分割,并使用 XLA RBG 进行数据生成。
  • “unsafe_rbg” 仅用于演示目的,使用 RBG 进行分割(使用未经测试的虚构算法)和生成。
  • 这些实验性实现生成的随机流尚未经过任何经验随机性测试(例如 Big Crush)。生成的随机比特可能会在 JAX 的不同版本之间变化。

不使用默认 RNG 的可能原因是:

  1. 可能编译速度较慢(特别是对于 Google Cloud TPU)
  2. 在 TPU 上执行速度较慢
  3. 不支持高效的自动分片/分区

这里是一个简短的总结:

属性 Threefry Threefry* rbg unsafe_rbg rbg** unsafe_rbg**
在 TPU 上最快
可以高效分片(使用 pjit)
在分片中相同
在 CPU/GPU/TPU 上相同
在 JAX/XLA 版本间相同

(*): 设置了jax_threefry_partitionable=1

(**): 设置了XLA_FLAGS=--xla_tpu_spmd_rng_bit_generator_unsafe=1

“rbg” 和 “unsafe_rbg” 之间的区别在于,“rbg” 用于生成随机值时使用了较不稳定/研究较少的哈希函数(但不用于  jax.random.split 或 jax.random.fold_in),而 “unsafe_rbg” 还额外在  jax.random.split 和 jax.random.fold_in  中使用了更不稳定的哈希函数。因此,在不同密钥生成的随机流质量方面不那么安全。

要了解有关 jax_threefry_partitionable 的更多信息,请参阅jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#generating-random-numbers

API 参考

密钥创建与操作

PRNGKey(seed, *[, impl]) 给定整数种子创建伪随机数生成器(PRNG)密钥。
key(seed, *[, impl]) 给定整数种子创建伪随机数生成器(PRNG)密钥。
key_data(密钥) 恢复 PRNG 密钥数组下的密钥数据位。
wrap_key_data(key_bits_array, *[, impl]) 将密钥数据位数组包装成 PRNG 密钥数组。
fold_in(key, data) 将数据折叠到 PRNG 密钥中,形成新的 PRNG 密钥。
split(key[, num]) 将 PRNG 密钥按添加一个前导轴拆分为 num 个新密钥。
clone(key) 克隆一个密钥以便重复使用。

随机抽样器

ball(key, d[, p, shape, dtype]) 从单位 Lp 球中均匀采样。
bernoulli(key[, p, shape]) 采样给定形状和均值的伯努利分布随机值。
beta(key, a, b[, shape, dtype]) 采样给定形状和浮点数数据类型的贝塔分布随机值。
binomial(key, n, p[, shape, dtype]) 采样给定形状和浮点数数据类型的二项分布随机值。
bits(key[, shape, dtype]) 以无符号整数的形式采样均匀比特。
categorical(key, logits[, axis, shape]) 从分类分布中采样随机值。
cauchy(key[, shape, dtype]) 采样给定形状和浮点数数据类型的柯西分布随机值。
chisquare(key, df[, shape, dtype]) 采样给定形状和浮点数数据类型的卡方分布随机值。
choice(key, a[, shape, replace, p, axis]) 从给定数组中生成随机样本。
dirichlet(key, alpha[, shape, dtype]) 采样给定形状和浮点数数据类型的狄利克雷分布随机值。
double_sided_maxwell(key, loc, scale[, …]) 从双边 Maxwell 分布中采样。
exponential(key[, shape, dtype]) 采样给定形状和浮点数数据类型的指数分布随机值。
f(key, dfnum, dfden[, shape, dtype]) 采样给定形状和浮点数数据类型的 F 分布随机值。
gamma(key, a[, shape, dtype]) 采样给定形状和浮点数数据类型的伽马分布随机值。
generalized_normal(key, p[, shape, dtype]) 从广义正态分布中采样。
geometric(key, p[, shape, dtype]) 采样给定形状和浮点数数据类型的几何分布随机值。
gumbel(key[, shape, dtype]) 采样给定形状和浮点数数据类型的 Gumbel 分布随机值。
laplace(key[, shape, dtype]) 采样给定形状和浮点数数据类型的拉普拉斯分布随机值。
loggamma(key, a[, shape, dtype]) 采样给定形状和浮点数数据类型的对数伽马分布随机值。
logistic(key[, shape, dtype]) 采样给定形状和浮点数数据类型的 logistic 随机值。
lognormal(key[, sigma, shape, dtype]) 采样给定形状和浮点数数据类型的对数正态分布随机值。
maxwell(key[, shape, dtype]) 从单边 Maxwell 分布中采样。
multivariate_normal(key, mean, cov[, shape, …]) 采样给定均值和协方差的多变量正态分布随机值。
normal(key[, shape, dtype]) 采样给定形状和浮点数数据类型的标准正态分布随机值。
orthogonal(key, n[, shape, dtype]) 从正交群 O(n) 中均匀采样。
pareto(key, b[, shape, dtype]) 采样给定形状和浮点数数据类型的帕累托分布随机值。
permutation(key, x[, axis, independent]) 返回随机排列的数组或范围。
poisson(key, lam[, shape, dtype]) 采样给定形状和整数数据类型的泊松分布随机值。
rademacher(key[, shape, dtype]) 从 Rademacher 分布中采样。
randint(key, shape, minval, maxval[, dtype]) 用给定的形状和数据类型在[minval, maxval)范围内示例均匀随机整数值。
[rayleigh(key, scale[, shape, dtype]) 用给定的形状和浮点数数据类型示例瑞利随机值。
t(key, df[, shape, dtype]) 用给定的形状和浮点数数据类型示例学生 t 分布随机值。
triangular(key, left, mode, right[, shape, …]) 用给定的形状和浮点数数据类型示例三角形随机值。
truncated_normal(key, lower, upper[, shape, …]) 用给定的形状和数据类型示例截断标准正态随机值。
uniform(key[, shape, dtype, minval, maxval]) 用给定的形状和数据类型在[minval, maxval)范围内示例均匀随机值。
[wald(key, mean[, shape, dtype]) 用给定的形状和浮点数数据类型示例瓦尔德随机值。
weibull_min(key, scale, concentration[, …]) 从威布尔分布中采样。


JAX 中文文档(十四)(4)https://developer.aliyun.com/article/1559759

相关文章
|
机器学习/深度学习 自然语言处理 并行计算
大模型开发:什么是Transformer架构及其重要性?
Transformer模型革新了NLP,以其高效的并行计算和自注意力机制解决了长距离依赖问题。从机器翻译到各种NLP任务,Transformer展现出卓越性能,其编码器-解码器结构结合自注意力层和前馈网络,实现高效训练。此架构已成为领域内重要里程碑。
1124 3
|
IDE Java Maven
Spring Boot之如何解决Maven依赖冲突Maven Helper 安装使用
Spring Boot之如何解决Maven依赖冲突Maven Helper 安装使用
736 2
|
12月前
|
缓存 IDE 调度
【HarmonyOS Next之旅】基于ArkTS开发(一) -> Ability开发一
本文介绍了HarmonyOS中的FA模型及其开发相关内容,包括PageAbility与ServiceAbility的开发方法。FA模型下的Ability分为多种类型,如PageAbility(带UI,用户可见可交互)、ServiceAbility(无UI,在后台提供服务)等。文章详细阐述了PageAbility的生命周期、启动模式及接口使用,并通过代码示例展示了如何启动本地PageAbility和重写生命周期函数。
355 12
|
10月前
|
存储 数据采集 API
小红书笔记详情API深度解析与实战指南(2025年最新版)
本文深入解析小红书开放平台笔记详情API的进阶使用与合规策略,涵盖接口升级、数据维度扩展、调用优化等内容,并提供Python调用示例及数据清洗存储方案。结合电商导购、舆情监控等实战场景,助力开发者高效获取并应用内容资产,同时强调数据隐私与平台政策合规要点,帮助构建稳定、安全的数据应用体系。
|
API 异构计算 索引
JAX 中文文档(十四)(2)
JAX 中文文档(十四)
254 0
|
Python
Python 进度条 tqdm模块
Python 进度条 tqdm模块
286 0
|
存储 前端开发 JavaScript
React 文件上传组件 File Upload
本文介绍了如何在 React 中实现文件上传组件,包括基本的概念、实现步骤、常见问题及解决方案。通过 `&lt;input type=&quot;file&quot;&gt;` 元素选择文件,使用 `fetch` 发送请求,处理文件类型和大小限制,以及多文件上传和进度条显示等高级功能,帮助开发者构建高效、可靠的文件上传组件。
1077 3
|
机器学习/深度学习 并行计算 PyTorch
深度学习环境搭建笔记(一):detectron2安装过程
这篇博客文章详细介绍了在Windows环境下,使用CUDA 10.2配置深度学习环境,并安装detectron2库的步骤,包括安装Python、pycocotools、Torch和Torchvision、fvcore,以及对Detectron2和PyTorch代码的修改。
3381 1
深度学习环境搭建笔记(一):detectron2安装过程
|
安全 应用服务中间件 Linux
判断一个网站是否使用HTTPS协议
判断一个网站是否使用HTTPS协议
3155 4
|
C语言
【数据结构】二叉树(c语言)(附源码)
本文介绍了如何使用链式结构实现二叉树的基本功能,包括前序、中序、后序和层序遍历,统计节点个数和树的高度,查找节点,判断是否为完全二叉树,以及销毁二叉树。通过手动创建一棵二叉树,详细讲解了每个功能的实现方法和代码示例,帮助读者深入理解递归和数据结构的应用。
1736 9

热门文章

最新文章