JAX 中文文档(十四)(5)

简介: JAX 中文文档(十四)

JAX 中文文档(十四)(4)https://developer.aliyun.com/article/1559759


jax.debug 模块

原文:jax.readthedocs.io/en/latest/jax.debug.html

运行时值调试实用工具

jax.debug.print 和 jax.debug.breakpoint 描述了如何利用 JAX 的运行时值调试功能。

callback(callback, *args[, ordered]) 调用可分阶段的 Python 回调函数。
print(fmt, *args[, ordered]) 打印值,并在 JAX 函数中工作。
breakpoint(*[, backend, filter_frames, …]) 在程序中某一点设置断点。

调试分片实用工具

能够在分段函数内(和外部)检查和可视化数组分片的函数。

inspect_array_sharding(value, *, callback) 在 JIT 编译函数内部启用检查数组分片。
visualize_array_sharding(arr, **kwargs) 可视化数组的分片。
visualize_sharding(shape, sharding, *[, …]) 使用 rich 可视化 Sharding

jax.dlpack 模块

原文:jax.readthedocs.io/en/latest/jax.dlpack.html

from_dlpack(external_array[, device, copy]) 返回一个 DLPack 张量的 Array 表示形式。
to_dlpack(x[, stream, src_device, …]) 返回一个封装了 Array x 的 DLPack 张量。

jax.distributed 模块

原文:jax.readthedocs.io/en/latest/jax.distributed.html

initialize([coordinator_address, …]) 初始化 JAX 分布式系统。
shutdown() 关闭分布式系统。

jax.dtypes 模块

原文:jax.readthedocs.io/en/latest/jax.dtypes.html

bfloat16 bfloat16 浮点数值
canonicalize_dtype(dtype[, allow_extended_dtype]) 根据config.x64_enabled配置将 dtype 转换为规范的 dtype。
float0 对应于相同名称的标量类型和 dtype 的 DType 类。
issubdtype(a, b) 如果第一个参数是类型代码在类型层次结构中较低/相等,则返回 True。
prng_key() PRNG Key dtypes 的标量类。
result_type(*args[, return_weak_type_flag]) 方便函数,用于应用 JAX 参数 dtype 提升。
scalar_type_of(x) 返回与 JAX 值关联的标量类型。

jax.flatten_util 模块

原文:jax.readthedocs.io/en/latest/jax.flatten_util.html

函数列表

- ravel_pytree(pytree) 将一个数组的 pytree 展平(压缩)为一个 1D 数组。

jax.image 模块

原文:jax.readthedocs.io/en/latest/jax.image.html

图像操作函数。

更多的图像操作函数可以在建立在 JAX 之上的库中找到,例如 PIX

图像操作函数

resize(image, shape, method[, antialias, …]) 图像调整大小。
scale_and_translate(image, shape, …[, …]) 对图像应用缩放和平移。

参数类

class jax.image.ResizeMethod(value)

图像调整大小方法。

可能的取值包括:

NEAREST:

最近邻插值。

LINEAR:

线性插值

LANCZOS3:

Lanczos 重采样,使用半径为 3 的核。

LANCZOS5:

Lanczos 重采样,使用半径为 5 的核。

CUBIC:

三次插值,使用 Keys 三次核。

jax.nn 模块

原文:jax.readthedocs.io/en/latest/jax.nn.html

  • jax.nn.initializers 模块

神经网络库常见函数。

激活函数

relu 线性整流单元激活函数。
relu6 线性整流单元 6 激活函数。
sigmoid(x) Sigmoid 激活函数。
softplus(x) Softplus 激活函数。
sparse_plus(x) 稀疏加法函数。
sparse_sigmoid(x) 稀疏 Sigmoid 激活函数。
soft_sign(x) Soft-sign 激活函数。
silu(x) SiLU(又称 swish)激活函数。
swish(x) SiLU(又称 swish)激活函数。
log_sigmoid(x) 对数 Sigmoid 激活函数。
leaky_relu(x[, negative_slope]) 泄漏整流线性单元激活函数。
hard_sigmoid(x) 硬 Sigmoid 激活函数。
hard_silu(x) 硬 SiLU(swish)激活函数。
hard_swish(x) 硬 SiLU(swish)激活函数。
hard_tanh(x) 硬\tanh 激活函数。
elu(x[, alpha]) 指数线性单元激活函数。
celu(x[, alpha]) 连续可微的指数线性单元激活函数。
selu(x) 缩放的指数线性单元激活函数。
gelu(x[, approximate]) 高斯误差线性单元激活函数。
glu(x[, axis]) 门控线性单元激活函数。
squareplus(x[, b]) Squareplus 激活函数。
mish(x) Mish 激活函数。

其他函数

softmax(x[, axis, where, initial]) Softmax 函数。
log_softmax(x[, axis, where, initial]) 对数 Softmax 函数。
logsumexp() 对数-总和-指数归约。
standardize(x[, axis, mean, variance, …]) 通过减去mean并除以(\sqrt{\mathrm{variance}})来标准化数组。
one_hot(x, num_classes, *[, dtype, axis]) 对给定索引进行 One-hot 编码。

jax.nn.initializers 模块

原文:jax.readthedocs.io/en/latest/jax.nn.initializers.html

与 Keras 和 Sonnet 中定义一致的常见神经网络层初始化器。

初始化器

该模块提供了与 Keras 和 Sonnet 中定义一致的常见神经网络层初始化器。

初始化器是一个函数,接受三个参数:(key, shape, dtype),并返回一个具有形状shape和数据类型dtype的数组。参数key是一个 PRNG 密钥(例如来自jax.random.key()),用于生成初始化数组的随机数。

constant(value[, dtype]) 构建一个返回常数值数组的初始化器。
delta_orthogonal([scale, column_axis, dtype]) 构建一个用于增量正交核的初始化器。
glorot_normal([in_axis, out_axis, …]) 构建一个 Glorot 正态初始化器(又称 Xavier 正态初始化器)。
glorot_uniform([in_axis, out_axis, …]) 构建一个 Glorot 均匀初始化器(又称 Xavier 均匀初始化器)。
he_normal([in_axis, out_axis, batch_axis, dtype]) 构建一个 He 正态初始化器(又称 Kaiming 正态初始化器)。
he_uniform([in_axis, out_axis, batch_axis, …]) 构建一个 He 均匀初始化器(又称 Kaiming 均匀初始化器)。
lecun_normal([in_axis, out_axis, …]) 构建一个 Lecun 正态初始化器。
lecun_uniform([in_axis, out_axis, …]) 构建一个 Lecun 均匀初始化器。
normal([stddev, dtype]) 构建一个返回实数正态分布随机数组的初始化器。
ones(key, shape[, dtype]) 返回一个填充为一的常数数组的初始化器。
orthogonal([scale, column_axis, dtype]) 构建一个返回均匀分布正交矩阵的初始化器。
truncated_normal([stddev, dtype, lower, upper]) 构建一个返回截断正态分布随机数组的初始化器。
uniform([scale, dtype]) 构建一个返回实数均匀分布随机数组的初始化器。
variance_scaling(scale, mode, distribution) 初始化器,根据权重张量的形状调整其尺度。
zeros(key, shape[, dtype]) 返回一个填充零的常数数组的初始化器。

jax.ops 模块

原文:jax.readthedocs.io/en/latest/jax.ops.html

段落约简运算符

| segment_max(data, segment_ids[, …]) | 计算数组段内的最大值。 |

函数 jax.ops.index_updatejax.ops.index_add 等已在 JAX 0.2.22 中弃用,并已移除。请改用 JAX 数组上的 jax.numpy.ndarray.at 属性。
segment_min(data, segment_ids[, …])
segment_prod(data, segment_ids[, …])
segment_sum(data, segment_ids[, …])

jax.profiler 模块

原文:jax.readthedocs.io/en/latest/jax.profiler.html

跟踪和时间分析

描述了如何利用 JAX 的跟踪和时间分析功能进行程序性能分析。

start_server(port) 在指定端口启动分析器服务器。
start_trace(log_dir[, create_perfetto_link, …]) 启动性能分析跟踪。
stop_trace() 停止当前正在运行的性能分析跟踪。
trace(log_dir[, create_perfetto_link, …]) 上下文管理器,用于进行性能分析跟踪。
annotate_function(func[, name]) 生成函数执行的跟踪事件的装饰器。
TraceAnnotation 在分析器中生成跟踪事件的上下文管理器。
StepTraceAnnotation(name, **kwargs) 在分析器中生成步骤跟踪事件的上下文管理器。

设备内存分析

请参阅设备内存分析,了解 JAX 的设备内存分析功能简介。

device_memory_profile([backend]) 捕获 JAX 设备内存使用情况,格式为 pprof 协议缓冲区。
save_device_memory_profile(filename[, backend]) 收集设备内存使用情况,并将其写入文件。

jax.stages 模块

原文:jax.readthedocs.io/en/latest/jax.stages.html

接口到编译执行过程的各个阶段。

JAX 转换,例如jax.jitjax.pmap,也支持一种通用的显式降阶和预编译执行 ahead of time 的方式。 该模块定义了代表这一过程各个阶段的类型。

有关更多信息,请参阅AOT walkthrough

class jax.stages.Wrapped(*args, **kwargs)

一个准备好进行追踪、降阶和编译的函数。

此协议反映了诸如jax.jit之类的函数的输出。 调用它会导致 JIT(即时)降阶、编译和执行。 它也可以在编译之前明确降阶,并在执行之前编译结果。

__call__(*args, **kwargs)

执行包装的函数,根据需要进行降阶和编译。

lower(*args, **kwargs)

明确为给定的参数降阶此函数。

一个降阶函数被从 Python 阶段化,并翻译为编译器的输入语言,可能以依赖于后端的方式。 它已准备好进行编译,但尚未编译。

返回:

一个Lowered实例,表示降阶。

返回类型:

降阶

trace(*args, **kwargs)

明确为给定的参数追踪此函数。

一个追踪函数被从 Python 阶段化,并翻译为一个 jaxpr。 它已准备好进行降阶,但尚未降阶。

返回:

一个Traced实例,表示追踪。

返回类型:

追踪

class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False)

降阶一个根据参数类型和值特化的函数。

降阶是一种准备好进行编译的计算。 此类将降阶与稍后编译和执行所需的剩余信息一起携带。 它还提供了一个通用的 API,用于查询 JAX 各种降阶路径(jit()pmap()等)中降阶计算的属性。

参数:

  • 降阶XlaLowering
  • args_infoAny
  • out_treePyTreeDef
  • no_kwargsbool
as_text(dialect=None)

此降阶的人类可读文本表示。

旨在可视化和调试目的。 这不必是有效的也不一定可靠的序列化。 它直接传递给外部调用者。

参数:

方言str | ) – 可选字符串,指定一个降阶方言(例如,“stablehlo”)

返回类型:

str

compile(compiler_options=None)

编译,并返回相应的Compiled实例。

参数:

compiler_options (dict[str, str | bool] | None)

返回类型:

Compiled

compiler_ir(dialect=None)

这种降低的任意对象表示。

旨在调试目的。这不是有效的也不是可靠的序列化。输出在不同调用之间没有一致性的保证。

如果不可用,则返回None,例如基于后端、编译器或运行时。

参数:

dialect (str | None) – 可选字符串,指定一个降低方言(例如“stablehlo”)

返回类型:

Any | None

cost_analysis()

执行成本估算的摘要。

旨在可视化和调试。此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如,带有数值叶的嵌套字典、列表和元组)。然而,它的结构可以是任意的:在 JAX 和 jaxlib 的不同版本甚至调用之间可能不一致。

如果不可用,则返回None,例如基于后端、编译器或运行时。

返回类型:

Any | None

property in_tree: PyTreeDef

一对(位置参数、关键字参数)的树结构。

class jax.stages.Compiled(executable, args_info, out_tree, no_kwargs=False)

编译后的函数专门针对类型/值进行了优化表示。

编译计算与可执行文件相关联,并提供执行所需的剩余信息。它还为查询 JAX 的各种编译路径和后端中编译计算属性提供了一个共同的 API。

参数:

  • args_info (Any)
  • out_tree (PyTreeDef)
__call__(*args, **kwargs)

将自身作为函数调用。

as_text()

这是可执行文件的人类可读文本表示。

旨在可视化和调试。这不是有效的也不是可靠的序列化。

如果不可用,则返回None,例如基于后端、编译器或运行时。

返回类型:

str | None

cost_analysis()

执行成本估算的摘要。

旨在可视化和调试。此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如,带有数值叶的嵌套字典、列表和元组)。然而,它的结构可以是任意的:在 JAX 和 jaxlib 的不同版本甚至调用之间可能不一致。

如果不可用,则返回None,例如基于后端、编译器或运行时。

返回类型:

Any | None

property in_tree: PyTreeDef

(位置参数,关键字参数) 的树结构。

memory_analysis()

估计内存需求的摘要。

用于可视化和调试目的。由此输出的对象是一些简单的数据结构,可以轻松打印或序列化(例如嵌套的字典、列表和具有数字叶子的元组)。然而,其结构可以是任意的:在 JAX 和 jaxlib 的不同版本之间,甚至在不同调用之间可能是不一致的。

返回 None 如果不可用,例如基于后端、编译器或运行时。

返回类型:

任意 | None

runtime_executable()

此可执行对象的任意对象表示。

用于调试目的。这不是有效也不是可靠的序列化。输出不能保证在不同调用之间的一致性。

返回 None 如果不可用,例如基于后端、编译器或运行时。

返回类型:

任意 | None

相关文章
|
4月前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(4)
JAX 中文文档(十五)
40 3
|
4月前
|
存储 API 索引
JAX 中文文档(十五)(5)
JAX 中文文档(十五)
53 3
|
4月前
|
算法 API 开发工具
JAX 中文文档(十二)(5)
JAX 中文文档(十二)
55 1
|
4月前
JAX 中文文档(十一)(5)
JAX 中文文档(十一)
17 1
|
4月前
|
API 异构计算 Python
JAX 中文文档(十一)(4)
JAX 中文文档(十一)
33 1
|
4月前
|
存储 缓存 API
JAX 中文文档(十六)(1)
JAX 中文文档(十六)
35 1
|
4月前
|
API 异构计算 索引
JAX 中文文档(十四)(2)
JAX 中文文档(十四)
44 0
|
4月前
|
资源调度 算法 安全
JAX 中文文档(十四)(3)
JAX 中文文档(十四)
45 0
|
4月前
|
Python
JAX 中文文档(十四)(4)
JAX 中文文档(十四)
32 0
|
4月前
|
关系型数据库
JAX 中文文档(十四)(1)
JAX 中文文档(十四)
33 0