JAX 中文文档(十四)(4)https://developer.aliyun.com/article/1559759
jax.debug 模块
运行时值调试实用工具
jax.debug.print 和 jax.debug.breakpoint 描述了如何利用 JAX 的运行时值调试功能。
callback (callback, *args[, ordered]) |
调用可分阶段的 Python 回调函数。 |
print (fmt, *args[, ordered]) |
打印值,并在 JAX 函数中工作。 |
breakpoint (*[, backend, filter_frames, …]) |
在程序中某一点设置断点。 |
调试分片实用工具
能够在分段函数内(和外部)检查和可视化数组分片的函数。
inspect_array_sharding (value, *, callback) |
在 JIT 编译函数内部启用检查数组分片。 |
visualize_array_sharding (arr, **kwargs) |
可视化数组的分片。 |
visualize_sharding (shape, sharding, *[, …]) |
使用 rich 可视化 Sharding 。 |
jax.dlpack 模块
from_dlpack (external_array[, device, copy]) |
返回一个 DLPack 张量的 Array 表示形式。 |
to_dlpack (x[, stream, src_device, …]) |
返回一个封装了 Array x 的 DLPack 张量。 |
jax.distributed 模块
initialize ([coordinator_address, …]) |
初始化 JAX 分布式系统。 |
shutdown () |
关闭分布式系统。 |
jax.dtypes 模块
bfloat16 |
bfloat16 浮点数值 |
canonicalize_dtype (dtype[, allow_extended_dtype]) |
根据config.x64_enabled 配置将 dtype 转换为规范的 dtype。 |
float0 |
对应于相同名称的标量类型和 dtype 的 DType 类。 |
issubdtype (a, b) |
如果第一个参数是类型代码在类型层次结构中较低/相等,则返回 True。 |
prng_key () |
PRNG Key dtypes 的标量类。 |
result_type (*args[, return_weak_type_flag]) |
方便函数,用于应用 JAX 参数 dtype 提升。 |
scalar_type_of (x) |
返回与 JAX 值关联的标量类型。 |
jax.flatten_util 模块
函数列表
- | ravel_pytree (pytree) |
将一个数组的 pytree 展平(压缩)为一个 1D 数组。 |
jax.image 模块
图像操作函数。
更多的图像操作函数可以在建立在 JAX 之上的库中找到,例如 PIX。
图像操作函数
resize (image, shape, method[, antialias, …]) |
图像调整大小。 |
scale_and_translate (image, shape, …[, …]) |
对图像应用缩放和平移。 |
参数类
class jax.image.ResizeMethod(value)
图像调整大小方法。
可能的取值包括:
NEAREST:
最近邻插值。
LINEAR:
线性插值。
LANCZOS3:
Lanczos 重采样,使用半径为 3 的核。
LANCZOS5:
Lanczos 重采样,使用半径为 5 的核。
CUBIC:
三次插值,使用 Keys 三次核。
jax.nn 模块
jax.nn.initializers
模块
神经网络库常见函数。
激活函数
relu |
线性整流单元激活函数。 |
relu6 |
线性整流单元 6 激活函数。 |
sigmoid (x) |
Sigmoid 激活函数。 |
softplus (x) |
Softplus 激活函数。 |
sparse_plus (x) |
稀疏加法函数。 |
sparse_sigmoid (x) |
稀疏 Sigmoid 激活函数。 |
soft_sign (x) |
Soft-sign 激活函数。 |
silu (x) |
SiLU(又称 swish)激活函数。 |
swish (x) |
SiLU(又称 swish)激活函数。 |
log_sigmoid (x) |
对数 Sigmoid 激活函数。 |
leaky_relu (x[, negative_slope]) |
泄漏整流线性单元激活函数。 |
hard_sigmoid (x) |
硬 Sigmoid 激活函数。 |
hard_silu (x) |
硬 SiLU(swish)激活函数。 |
hard_swish (x) |
硬 SiLU(swish)激活函数。 |
hard_tanh (x) |
硬\tanh 激活函数。 |
elu (x[, alpha]) |
指数线性单元激活函数。 |
celu (x[, alpha]) |
连续可微的指数线性单元激活函数。 |
selu (x) |
缩放的指数线性单元激活函数。 |
gelu (x[, approximate]) |
高斯误差线性单元激活函数。 |
glu (x[, axis]) |
门控线性单元激活函数。 |
squareplus (x[, b]) |
Squareplus 激活函数。 |
mish (x) |
Mish 激活函数。 |
其他函数
softmax (x[, axis, where, initial]) |
Softmax 函数。 |
log_softmax (x[, axis, where, initial]) |
对数 Softmax 函数。 |
logsumexp () |
对数-总和-指数归约。 |
standardize (x[, axis, mean, variance, …]) |
通过减去mean 并除以(\sqrt{\mathrm{variance}})来标准化数组。 |
one_hot (x, num_classes, *[, dtype, axis]) |
对给定索引进行 One-hot 编码。 |
jax.nn.initializers 模块
与 Keras 和 Sonnet 中定义一致的常见神经网络层初始化器。
初始化器
该模块提供了与 Keras 和 Sonnet 中定义一致的常见神经网络层初始化器。
初始化器是一个函数,接受三个参数:(key, shape, dtype)
,并返回一个具有形状shape
和数据类型dtype
的数组。参数key
是一个 PRNG 密钥(例如来自jax.random.key()
),用于生成初始化数组的随机数。
constant (value[, dtype]) |
构建一个返回常数值数组的初始化器。 |
delta_orthogonal ([scale, column_axis, dtype]) |
构建一个用于增量正交核的初始化器。 |
glorot_normal ([in_axis, out_axis, …]) |
构建一个 Glorot 正态初始化器(又称 Xavier 正态初始化器)。 |
glorot_uniform ([in_axis, out_axis, …]) |
构建一个 Glorot 均匀初始化器(又称 Xavier 均匀初始化器)。 |
he_normal ([in_axis, out_axis, batch_axis, dtype]) |
构建一个 He 正态初始化器(又称 Kaiming 正态初始化器)。 |
he_uniform ([in_axis, out_axis, batch_axis, …]) |
构建一个 He 均匀初始化器(又称 Kaiming 均匀初始化器)。 |
lecun_normal ([in_axis, out_axis, …]) |
构建一个 Lecun 正态初始化器。 |
lecun_uniform ([in_axis, out_axis, …]) |
构建一个 Lecun 均匀初始化器。 |
normal ([stddev, dtype]) |
构建一个返回实数正态分布随机数组的初始化器。 |
ones (key, shape[, dtype]) |
返回一个填充为一的常数数组的初始化器。 |
orthogonal ([scale, column_axis, dtype]) |
构建一个返回均匀分布正交矩阵的初始化器。 |
truncated_normal ([stddev, dtype, lower, upper]) |
构建一个返回截断正态分布随机数组的初始化器。 |
uniform ([scale, dtype]) |
构建一个返回实数均匀分布随机数组的初始化器。 |
variance_scaling (scale, mode, distribution) |
初始化器,根据权重张量的形状调整其尺度。 |
zeros (key, shape[, dtype]) |
返回一个填充零的常数数组的初始化器。 |
jax.ops 模块
段落约简运算符
| segment_max
(data, segment_ids[, …]) | 计算数组段内的最大值。 |
函数 jax.ops.index_update 、jax.ops.index_add 等已在 JAX 0.2.22 中弃用,并已移除。请改用 JAX 数组上的 jax.numpy.ndarray.at 属性。 |
segment_min (data, segment_ids[, …]) |
segment_prod (data, segment_ids[, …]) |
segment_sum (data, segment_ids[, …]) |
jax.profiler 模块
跟踪和时间分析
描述了如何利用 JAX 的跟踪和时间分析功能进行程序性能分析。
start_server (port) |
在指定端口启动分析器服务器。 |
start_trace (log_dir[, create_perfetto_link, …]) |
启动性能分析跟踪。 |
stop_trace () |
停止当前正在运行的性能分析跟踪。 |
trace (log_dir[, create_perfetto_link, …]) |
上下文管理器,用于进行性能分析跟踪。 |
annotate_function (func[, name]) |
生成函数执行的跟踪事件的装饰器。 |
TraceAnnotation |
在分析器中生成跟踪事件的上下文管理器。 |
StepTraceAnnotation (name, **kwargs) |
在分析器中生成步骤跟踪事件的上下文管理器。 |
设备内存分析
请参阅设备内存分析,了解 JAX 的设备内存分析功能简介。
device_memory_profile ([backend]) |
捕获 JAX 设备内存使用情况,格式为 pprof 协议缓冲区。 |
save_device_memory_profile (filename[, backend]) |
收集设备内存使用情况,并将其写入文件。 |
jax.stages 模块
接口到编译执行过程的各个阶段。
JAX 转换,例如jax.jit
和jax.pmap
,也支持一种通用的显式降阶和预编译执行 ahead of time 的方式。 该模块定义了代表这一过程各个阶段的类型。
有关更多信息,请参阅AOT walkthrough。
类
class jax.stages.Wrapped(*args, **kwargs)
一个准备好进行追踪、降阶和编译的函数。
此协议反映了诸如jax.jit
之类的函数的输出。 调用它会导致 JIT(即时)降阶、编译和执行。 它也可以在编译之前明确降阶,并在执行之前编译结果。
__call__(*args, **kwargs)
执行包装的函数,根据需要进行降阶和编译。
lower(*args, **kwargs)
明确为给定的参数降阶此函数。
一个降阶函数被从 Python 阶段化,并翻译为编译器的输入语言,可能以依赖于后端的方式。 它已准备好进行编译,但尚未编译。
返回:
一个Lowered
实例,表示降阶。
返回类型:
降阶
trace(*args, **kwargs)
明确为给定的参数追踪此函数。
一个追踪函数被从 Python 阶段化,并翻译为一个 jaxpr。 它已准备好进行降阶,但尚未降阶。
返回:
一个Traced
实例,表示追踪。
返回类型:
追踪
class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False)
降阶一个根据参数类型和值特化的函数。
降阶是一种准备好进行编译的计算。 此类将降阶与稍后编译和执行所需的剩余信息一起携带。 它还提供了一个通用的 API,用于查询 JAX 各种降阶路径(jit()
、pmap()
等)中降阶计算的属性。
参数:
as_text(dialect=None)
此降阶的人类可读文本表示。
旨在可视化和调试目的。 这不必是有效的也不一定可靠的序列化。 它直接传递给外部调用者。
参数:
方言(str | 无) – 可选字符串,指定一个降阶方言(例如,“stablehlo”)
返回类型:
compile(compiler_options=None)
编译,并返回相应的Compiled
实例。
参数:
compiler_options (dict[str, str | bool] | None)
返回类型:
Compiled
compiler_ir(dialect=None)
这种降低的任意对象表示。
旨在调试目的。这不是有效的也不是可靠的序列化。输出在不同调用之间没有一致性的保证。
如果不可用,则返回None
,例如基于后端、编译器或运行时。
参数:
dialect (str | None) – 可选字符串,指定一个降低方言(例如“stablehlo”)
返回类型:
Any | None
cost_analysis()
执行成本估算的摘要。
旨在可视化和调试。此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如,带有数值叶的嵌套字典、列表和元组)。然而,它的结构可以是任意的:在 JAX 和 jaxlib 的不同版本甚至调用之间可能不一致。
如果不可用,则返回None
,例如基于后端、编译器或运行时。
返回类型:
Any | None
property in_tree: PyTreeDef
一对(位置参数、关键字参数)的树结构。
class jax.stages.Compiled(executable, args_info, out_tree, no_kwargs=False)
编译后的函数专门针对类型/值进行了优化表示。
编译计算与可执行文件相关联,并提供执行所需的剩余信息。它还为查询 JAX 的各种编译路径和后端中编译计算属性提供了一个共同的 API。
参数:
- args_info (Any)
- out_tree (PyTreeDef)
__call__(*args, **kwargs)
将自身作为函数调用。
as_text()
这是可执行文件的人类可读文本表示。
旨在可视化和调试。这不是有效的也不是可靠的序列化。
如果不可用,则返回None
,例如基于后端、编译器或运行时。
返回类型:
str | None
cost_analysis()
执行成本估算的摘要。
旨在可视化和调试。此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如,带有数值叶的嵌套字典、列表和元组)。然而,它的结构可以是任意的:在 JAX 和 jaxlib 的不同版本甚至调用之间可能不一致。
如果不可用,则返回None
,例如基于后端、编译器或运行时。
返回类型:
Any | None
property in_tree: PyTreeDef
(位置参数,关键字参数) 的树结构。
memory_analysis()
估计内存需求的摘要。
用于可视化和调试目的。由此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如嵌套的字典、列表和具有数字叶子的元组)。然而,其结构可以是任意的:在 JAX 和 jaxlib 的不同版本之间,甚至在不同调用之间可能是不一致的。
返回 None
如果不可用,例如基于后端、编译器或运行时。
返回类型:
任意 | None
runtime_executable()
此可执行对象的任意对象表示。
用于调试目的。这不是有效也不是可靠的序列化。输出不能保证在不同调用之间的一致性。
返回 None
如果不可用,例如基于后端、编译器或运行时。
返回类型:
任意 | None