JAX 中文文档(十六)(2)https://developer.aliyun.com/article/1559727
jaxlib 0.4.14(2023 年 7 月 27 日)
- 弃用
- 根据 https://jax.readthedocs.io/en/latest/deprecation.html,不再支持 Python 3.8。
jax 0.4.13(2023 年 6 月 22 日)
- 更改
jax.jit
现在允许将None
传递给in_shardings
和out_shardings
。语义如下:
- 对于
in_shardings
,JAX 将其标记为复制,但这种行为可能会在将来更改。 - 对于
out_shardings
,我们将依赖于 XLA GSPMD 分区器来确定输出的分片。
jax.experimental.pjit.pjit
也允许将None
传递给in_shardings
和out_shardings
。语义如下:
- 如果未提供网格上下文管理器,则 JAX 可自由选择所需的分片方式。
- 对于
in_shardings
,JAX 将其标记为复制,但这种行为可能会在将来更改。 - 对于
out_shardings
,我们将依赖于 XLA GSPMD 分区器来确定输出的分片。
- 如果提供了网格上下文管理器,
None
将意味着该值将在网格的所有设备上复制。
- Executable.cost_analysis() 在 Cloud TPU 上可用
- 如果正在使用非允许的
jaxlib
插件,则添加了警告。 - 添加了
jax.tree_util.tree_leaves_with_path
。 None
不是jax.experimental.multihost_utils.host_local_array_to_global_array
或jax.experimental.multihost_utils.global_array_to_host_local_array
的有效输入。如果您希望复制您的输入,请使用jax.sharding.PartitionSpec()
。
- Bug 修复
- 在 CUDA 12 发布中修复了错误的轮子名称(#16362);正确的轮子名称为
cudnn89
而不是cudnn88
。
- 弃用
jax.experimental.jax2tf.convert()
的native_serialization_strict_checks
参数已被弃用,推荐使用新的native_serializaation_disabled_checks
(#16347)。
jaxlib 0.4.13(2023 年 6 月 22 日)
- 更改
- 将 Windows 仅 CPU 轮子添加到
jaxlib
Pypi 发布中。
- Bug 修复
__cuda_array_interface__
在之前的 jaxlib 版本中出现问题,现已修复(#16440)。- 并行 CUDA 内核跟踪现在默认启用于 NVIDIA GPU。
jax 0.4.12(2023 年 6 月 8 日)
- 更改
- 弃用
jax.abstract_arrays
及其内容已被弃用。请参阅:mod:jax.core
中的相关功能。jax.numpy.alltrue
:使用jax.numpy.all
。这遵循了 NumPy 版本 1.25.0 中numpy.alltrue
的弃用。jax.numpy.sometrue
:使用jax.numpy.any
。这遵循了 NumPy 版本 1.25.0 中numpy.sometrue
的弃用。jax.numpy.product
:使用jax.numpy.prod
。这遵循了 NumPy 版本 1.25.0 中numpy.product
的弃用。jax.numpy.cumproduct
:使用jax.numpy.cumprod
。这遵循了 NumPy 版本 1.25.0 中numpy.cumproduct
的弃用。jax.sharding.OpShardingSharding
已被移除,因为它已经弃用了 3 个月。
jaxlib 0.4.12 (2023 年 6 月 8 日)
- 变更
- 包含了 Hopper(SM 版本 9.0+)GPU 的 PTX/SASS。之前的 jaxlib 版本应该可以在 Hopper 上工作,但第一次执行 JAX 操作时可能会有较长的 JIT 编译延迟。
- Bug 修复
- 修复了在 Python 3.11 下 JAX 生成的 Python 回溯中源代码行信息不正确的问题。
- 修复了在 JAX 生成的 Python 回溯的帧中打印本地变量时崩溃的问题(#16027)。
jax 0.4.11 (2023 年 5 月 31 日)
- 弃用
- 根据 API 兼容性政策,在 3 个月的弃用期后,已移除以下 API:
jax.experimental.PartitionSpec
:使用jax.sharding.PartitionSpec
。jax.experimental.maps.Mesh
:使用jax.sharding.Mesh
。jax.experimental.pjit.NamedSharding
:使用jax.sharding.NamedSharding
。jax.experimental.pjit.PartitionSpec
:使用jax.sharding.PartitionSpec
。jax.experimental.pjit.FROM_GDA
。请将分片的jax.Array
对象作为输入传递,并删除pjit
的可选in_shardings
参数。jax.interpreters.pxla.PartitionSpec
:使用jax.sharding.PartitionSpec
。jax.interpreters.pxla.Mesh
:使用jax.sharding.Mesh
。jax.interpreters.xla.Buffer
:使用jax.Array
。jax.interpreters.xla.Device
:使用jax.Device
。jax.interpreters.xla.DeviceArray
:使用jax.Array
。jax.interpreters.xla.device_put
:使用jax.device_put
。jax.interpreters.xla.xla_call_p
:使用jax.experimental.pjit.pjit_p
。with_sharding_constraint
的axis_resources
参数已被移除。请改用shardings
。
jaxlib 0.4.11 (2023 年 5 月 31 日)
- 变更
- 向
Device
添加了memory_stats()
方法。如果支持,它将返回一个包含字符串统计名称和整数值的字典,例如"bytes_in_use"
,如果平台不支持内存统计,则返回 None。具体的统计数据可能因平台而异。目前仅在 Cloud TPU 上实现。 - 重新添加了对 CPU 设备上 Python 缓冲协议(
memoryview
)的支持。
jax 0.4.10 (2023 年 5 月 11 日)
jaxlib 0.4.10 (2023 年 5 月 11 日)
- 变更
- 修复了阻止上一个版本在 Mac M1 上运行的
'apple-m1' is not a recognized processor for this target (ignoring processor)
问题。
jax 0.4.9 (2023 年 5 月 9 日)
- 变更
experimental_cpp_jit
、experimental_cpp_pjit
和experimental_cpp_pmap
标志已被移除。它们现在始终开启。- TPU 上奇异值分解(SVD)的准确性已经提高(需要 jaxlib 0.4.9)。
- 废弃功能
jax.experimental.gda_serialization
已废弃,并已重命名为jax.experimental.array_serialization
。请更改您的导入以使用jax.experimental.array_serialization
。pjit
的in_axis_resources
和out_axis_resources
参数已废弃。请分别使用in_shardings
和out_shardings
。- 函数
jax.numpy.msort
已被移除。自 JAX v0.4.1 起已被废弃。请使用jnp.sort(a, axis=0)
代替。 in_parts
和out_parts
参数已从jax.xla_computation
中移除,因为它们只与sharded_jit
一起使用,并且sharded_jit
已不再使用。- 自从很久以来未被使用,
instantiate_const_outputs
参数已从jax.xla_computation
中移除。
jaxlib 0.4.9(2023 年 5 月 9 日)
jax 0.4.8(2023 年 3 月 29 日)
- 破坏性变更
- Cloud TPU 运行时的一个重要组件已升级。这使得以下新功能在 Cloud TPU 上可用:
jax.debug.print()
、jax.debug.callback()
和jax.debug.breakpoint()
现在在 Cloud TPU 上可用。- 自动 TPU 内存碎片整理
- 在新的运行时组件上,不再支持
jax.experimental.host_callback()
在 Cloud TPU 上的使用。如果新的jax.debug
API 不能满足您的需求,请在JAX 问题跟踪器上提出问题。
旧的运行时组件将通过设置环境变量JAX_USE_PJRT_C_API_ON_TPU=false
至少在接下来的三个月内可用。如果您发现需要出于任何原因禁用新的运行时,请在JAX 问题跟踪器上告知我们。
- 变更
- 最低 jaxlib 版本已从 0.4.6 提升至 0.4.7。
- 废弃功能
- 支持 CUDA 11.4 已被移除。JAX GPU 版本仅支持 CUDA 11.8 和 CUDA 12。如果使用旧版 CUDA 构建 jaxlib 可能会正常工作。
pmap
的global_arg_shapes
参数仅适用于sharded_jit
,已从pmap
中移除。请迁移到pjit
并从pmap
中移除global_arg_shapes
。
jax 0.4.7(2023 年 3 月 27 日)
- 变更
- 根据 https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration,不再支持禁用
jax.config.jax_array
。 - 不再支持禁用
jax.config.jax_jit_pjit_api_merge
。 jax.experimental.jax2tf.convert()
现在支持native_serialization
参数,使用 JAX 的本机降级到 StableHLO 以获取整个 JAX 函数的 StableHLO 模块,而不是将每个 JAX 原语降级到 TensorFlow 操作。这简化了内部操作,并增加了您序列化内容与 JAX 本机语义匹配的信心。详见文档。作为这一变更的一部分,配置标志--jax2tf_default_experimental_native_lowering
已重命名为--jax2tf_native_serialization
。- JAX 现在依赖于
ml_dtypes
,其中包含类似于 bfloat16 的 NumPy 类型的定义。这些定义以前是 JAX 的内部部分,但已拆分为一个单独的包,以便与其他项目共享。 - JAX 现在要求使用 NumPy 1.21 或更新版本以及 SciPy 1.7 或更新版本。
- 弃用信息
- 类型
jax.numpy.DeviceArray
已弃用。请改用jax.Array
,它是其别名。 - 类型
jax.interpreters.pxla.ShardedDeviceArray
已弃用。请改用jax.Array
。 - 通过位置传递额外参数给
jax.numpy.ndarray.at()
已被弃用。例如,不要使用x.at[i].get(True)
,而是使用x.at[i].get(indices_are_sorted=True)
jax.interpreters.xla.device_put
已被弃用。请使用jax.device_put
。jax.interpreters.pxla.device_put
已被弃用。请使用jax.device_put
。jax.experimental.pjit.FROM_GDA
已被弃用。请将分片的 jax.Arrays 作为输入,并移除 pjit 中的in_shardings
参数,因为它是可选的。
jaxlib 0.4.7(2023 年 3 月 27 日)
变更:
- jaxlib 现在依赖于
ml_dtypes
,其中包含类似于 bfloat16 的 NumPy 类型的定义。这些定义以前是 JAX 的内部部分,但已拆分为一个单独的包,以便与其他项目共享。
jax 0.4.6(2023 年 3 月 9 日)
- 变更
jax.tree_util
现在包含一组允许用户为其自定义 pytree 节点定义键的 API。
tree_flatten_with_path
可以展平树并返回每个叶子及其键路径。tree_map_with_path
可以映射一个接受键路径作为参数的函数。register_pytree_with_keys
用于注册自定义 pytree 节点中键路径和叶子的外观。keystr
用于漂亮地打印键路径。
jax2tf.call_tf()
现在有一个新参数output_shape_dtype
(默认为None
),可用于声明结果的输出形状和类型。这使得jax2tf.call_tf()
能够在形状多态性存在的情况下工作。(#14734)
- 弃用信息
jax.tree_util
中的旧键路径 API 已被弃用,并将在 2023 年 3 月 10 日后的 3 个月内移除:
register_keypaths
:请使用jax.tree_util.register_pytree_with_keys()
替代。AttributeKeyPathEntry
:请改用GetAttrKey
。GetitemKeyPathEntry
:请改用SequenceKey
或DictKey
。
jaxlib 0.4.6(2023 年 3 月 9 日)
jax 0.4.5(2023 年 3 月 2 日)
- 弃用信息
jax.sharding.OpShardingSharding
已重命名为jax.sharding.GSPMDSharding
。jax.sharding.OpShardingSharding
将在 2023 年 2 月 17 日后的 3 个月内移除。- 下列
jax.Array
方法已被弃用,并将在 2023 年 2 月 23 日后的 3 个月内移除:
jax.Array.broadcast
:请使用jax.lax.broadcast()
替代。jax.Array.broadcast_in_dim
:请使用jax.lax.broadcast_in_dim()
替代。jax.Array.split
:请使用jax.numpy.split()
替代。
jax 0.4.4(2023 年 2 月 16 日)
- 变更
jit
和pjit
的实现已合并。合并 jit 和 pjit 改变了 JAX 的内部实现,但不影响 JAX 的公共 API。之前,jit
是一种最终风格的原语。最终风格意味着尽可能延迟创建 jaxpr 并将变换堆叠在一起。随着jit
-pjit
实现的合并,jit
变成了一种初始风格的原语,这意味着我们尽早追踪到 jaxpr。更多信息请参见 autodidax 中的这一部分。转移到初始风格应该简化 JAX 的内部实现,并使得动态形状等功能的开发更加容易。你只能通过环境变量来禁用它,即os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'
。由于它影响到 JAX 的导入时机,因此必须通过环境变量禁用它,在导入 jax 之前就需要禁用它。with_sharding_constraint
的axis_resources
参数已弃用。请改用shardings
。如果你将其作为参数使用,则无需更改。如果你将其作为关键字参数使用,请改用shardings
。axis_resources
将在 2023 年 2 月 13 日后的 3 个月内删除。- 添加了
jax.typing
模块,用于 JAX 函数的类型注解工具。 - 下列名称已被弃用:
jax.xla.Device
和jax.interpreters.xla.Device
: 使用jax.Device
。jax.experimental.maps.Mesh
. 使用jax.sharding.Mesh
替代。jax.experimental.pjit.NamedSharding
: 使用jax.sharding.NamedSharding
。jax.experimental.pjit.PartitionSpec
: 使用jax.sharding.PartitionSpec
。jax.interpreters.pxla.Mesh
: 使用jax.sharding.Mesh
。jax.interpreters.pxla.PartitionSpec
: 使用jax.sharding.PartitionSpec
。
- Breaking Changes
jax.numpy.sum
等的initial
参数现在要求是一个标量,与对应的 NumPy API 保持一致。以前的行为是对非标量initial
值进行广播,这是一个意外的实现细节(#14446)。
jaxlib 0.4.4(2023 年 2 月 16 日)
- Breaking changes
- 默认的
jaxlib
构建中已移除对 NVIDIA Kepler 系列 GPU 的支持。如果需要 Kepler 支持,可以通过使用 Kepler 支持的源码构建jaxlib
(通过build.py
的--cuda_compute_capabilities=sm_35
选项),不过请注意 CUDA 12 已完全停止支持 Kepler GPU。
jax 0.4.3(2023 年 2 月 8 日)
- Breaking changes
- 删除了
jax.scipy.linalg.polar_unitary()
,这是一个已弃用的 JAX 扩展到 scipy API 的函数。请改用jax.scipy.linalg.polar()
。
- Changes
- 添加了
jax.scipy.stats.rankdata()
。
jaxlib 0.4.3(2023 年 2 月 8 日)
jax.Array
现在具有非阻塞的is_ready()
方法,如果数组已准备就绪则返回True
(参见jax.block_until_ready()
)。
jax 0.4.2(2023 年 1 月 24 日)
- Breaking changes
- 删除了
jax.experimental.callback
- 在存在
jax2tf
形状多态性的情况下,对带有维度的操作进行了泛化处理,通过将符号维度转换为 JAX 数组来在更多场景下工作。现在,涉及符号维度和np.ndarray
的操作在结果用作形状值时可能会引发错误(#14106)。 - 现在,
jaxpr
对象在设置属性时会引发错误,以避免问题变异(#14102)
- 变更
jax2tf.call_tf()
现在有一个新参数has_side_effects
(默认为True
),可用于声明实例是否可以被 JAX 优化(如死代码消除)删除或复制(#13980)。- 为了支持
jax2tf
形状多态性的floordiv
和mod
,我们增加了更多支持。之前,存在符号维度时某些除法操作会导致错误(#14108)。
jaxlib 0.4.2(2023 年 1 月 24 日)
- 变更
- 设置
JAX_USE_PJRT_C_API_ON_TPU=1
可启用新的 Cloud TPU 运行时,具备自动设备内存碎片整理功能。
jax 0.4.1(2022 年 12 月 13 日)
- 变更
- 根据 JAX 的 Python 和 NumPy 版本支持政策,不再支持 Python 3.7。
- 我们引入了
jax.Array
,它是 JAX 中的统一数组类型,涵盖了DeviceArray
、ShardedDeviceArray
和GlobalDeviceArray
类型。jax.Array
类型有助于使并行成为 JAX 的核心特性,简化和统一 JAX 内部结构,并允许我们统一jit
和pjit
。jax.Array
已在 JAX 0.4 中默认启用,并对pjit
API 进行了一些破坏性更改。jax.Array 迁移指南 可帮助您将代码库迁移到jax.Array
。您还可以查看Distributed arrays and automatic parallelization 教程,以理解新概念。 PartitionSpec
和Mesh
现在不再处于实验阶段。新的 API 端点是jax.sharding.PartitionSpec
和jax.sharding.Mesh
。jax.experimental.maps.Mesh
和jax.experimental.PartitionSpec
已被弃用,并将在三个月内移除。with_sharding_constraint
的新公共端点是jax.lax.with_sharding_constraint
。- 如果与
jax.config
一起使用 ABSL 标志,那么在最初从 ABSL 标志填充 JAX 配置选项后,就不再读取或写入 ABSL 标志值。此更改改进了读取jax.config
选项的性能,这些选项在 JAX 中广泛使用。 jax2tf.call_tf
函数现在使用与嵌入 JAX 计算相同平台的第一个 TF 设备进行 TF 降级。以前,它使用的是 JAX 默认后端的第 0 个设备。- 现在,一些
jax.numpy
函数的参数已标记为仅限位置参数,与 NumPy 匹配。 jnp.msort
现已废弃,遵循 numpy 1.24 中np.msort
的废弃。它将在未来的版本中移除,符合 API 兼容性策略。可以用jnp.sort(a, axis=0)
替换。
jaxlib 0.4.1 (2022 年 12 月 13 日)
- 变更
- 支持 Python 3.7 已被放弃,符合 JAX 的 Python 和 NumPy 版本支持政策。
XLA_PYTHON_CLIENT_MEM_FRACTION=.XX
的行为已更改,现在分配总 GPU 内存的 XX%来预分配,而不是以前使用当前可用 GPU 内存来计算预分配。有关更多详情,请参阅GPU memory allocation。- 废弃的方法
.block_host_until_ready()
已被移除。请改用.block_until_ready()
。
jax 0.4.0 (2022 年 12 月 12 日)
- 此版本已被撤回。
jaxlib 0.4.0 (2022 年 12 月 12 日)
- 此版本已被撤回。
jax 0.3.25 (2022 年 11 月 15 日)
- 变更
jax.numpy.linalg.pinv()
现在支持hermitian
选项。jax.scipy.linalg.hessenberg()
现在仅在 CPU 上支持。需要 jaxlib > 0.3.24。- 新函数
jax.lax.linalg.hessenberg()
,jax.lax.linalg.tridiagonal()
和jax.lax.linalg.householder_product()
已添加。Householder 约简目前仅支持 CPU,三对角约简支持 CPU 和 GPU。 - 现在更经济地计算非方阵的
svd
和jax.numpy.linalg.pinv
的梯度。
- 突破性变更
- 删除了
jax_experimental_name_stack
配置选项。 - 将字符串
axis_names
参数转换为jax.experimental.maps.Mesh
构造函数的单例元组,而不是将字符串解包为字符轴名称序列。
jaxlib 0.3.25 (2022 年 11 月 15 日)
- 变更
- 添加了对 CPU 和 GPU 上三对角约简的支持。
- 添加了对 CPU 上上 Hessenberg 约简的支持。
- Bug 修复
- 修复了一个 bug,导致 JAX 捕获的回溯中的帧被错误地映射到 Python 3.10+下的源行。
jax 0.3.24 (2022 年 11 月 4 日)
- 变更
- JAX 导入速度应更快。现在我们懒惰地导入 scipy,这在 JAX 的导入时间中占据了相当大的部分。
- 设置环境变量
JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N
可以用于限制写入持久缓存的缓存条目数量。默认情况下,编译时间超过 1 秒的计算将被缓存。
- 添加了
jax.scipy.stats.mode()
。
- 如果在 TPU 上未指定顺序,则
pmap
的默认设备顺序现在与单进程作业的jax.devices()
匹配。以前两种排序不同,可能导致不必要的复制或内存不足错误。要求排序一致简化了问题。
- 突破性变更
jax.numpy.gradient()
现在像jax.numpy
中的大多数其他函数一样,禁止传递列表或元组以替代数组(#12958)。jax.numpy.linalg
和jax.numpy.fft
中的函数现在统一要求输入为数组形式:即不能使用列表和元组代替数组。部分属于#7737。
- 弃用
jax.sharding.MeshPspecSharding
已重命名为jax.sharding.NamedSharding
。jax.sharding.MeshPspecSharding
名称将在 3 个月内删除。
jaxlib 0.3.24(2022 年 11 月 4 日)
- 更改
- 现在在 CPU 上可以使用缓冲器捐赠。这可能会破坏在 CPU 上标记缓冲区进行捐赠但依赖捐赠未实现的代码。
jax 0.3.23(2022 年 10 月 12 日)
- 更改
- 更新 Colab TPU 驱动程序版本以支持新的 jaxlib 发布。
jax 0.3.22(2022 年 10 月 11 日)
- 更改
- 在 TPU 初始化中添加
JAX_PLATFORMS=tpu,cpu
作为默认设置,因此如果无法初始化 TPU,JAX 将引发错误而不是回退到 CPU。设置JAX_PLATFORMS=''
以覆盖此行为并自动选择可用的后端(原始默认),或设置JAX_PLATFORMS=cpu
以始终使用 CPU,而不管 TPU 是否可用。
- 弃用
- JAX v0.3.8 中弃用的几个测试工具现已从
jax.test_util
中移除。
jaxlib 0.3.22(2022 年 10 月 11 日)
jax 0.3.21(2022 年 9 月 30 日)
- GitHub 提交记录。
- 更改
- 持久化编译缓存现在在出错时会发出警告而不是抛出异常(#12582),所以如果缓存出现问题,程序可以继续执行。设置
JAX_RAISE_PERSISTENT_CACHE_ERRORS=true
可以恢复此行为。
jax 0.3.20(2022 年 9 月 28 日)
- Bug 修复:
- 添加了上一个发布版本中缺失的
.pyi
文件(#12536)。 - 修复了
jax
0.3.19 与其固定的 libtpu 版本之间的不兼容性(#12550)。需要 jaxlib 0.3.20。 - 修复了
setup.py
注释中pip
的错误网址(#12528)。
jaxlib 0.3.20(2022 年 9 月 28 日)
- GitHub 提交记录。
- Bug 修复
- 修复通过
jax_cuda_visible_devices
在分布式作业中限制可见 CUDA 设备的支持。此功能对于 GPU 上的 JAX/SLURM 集成非常重要(#12533)。
jax 0.3.19(2022 年 9 月 27 日)
- GitHub 提交记录。
- 需要的 jaxlib 版本修复。
jax 0.3.18(2022 年 9 月 26 日)
- GitHub 提交记录。
- 更改
- 提前编译和编译功能(在#7733中跟踪)是稳定和公开的。查看概述和
jax.stages
的 API 文档。 - 引入了
jax.Array
,用于 JAX 中数组类型的isinstance
检查和类型注释。请注意,这包括了对jax.numpy.ndarray
在 JAX 内部对象中如何工作的一些微妙更改,因为jax.numpy.ndarray
现在是jax.Array
的简单别名。
- 破坏性变更
jax._src
不再导入公共jax
命名空间。这可能会打破使用 JAX 内部功能的用户。- 已删除
jax.soft_pmap
。请改用pjit
或xmap
。jax.soft_pmap
未记录文档。如果有文档记录,将提供弃用期。
jax 0.3.17(2022 年 8 月 31 日)
- GitHub 提交记录。
- 错误修复
- 修复了
lax.pow
的梯度在指数为零时的特殊情况问题(#12041)
- 破坏性变更
jax.checkpoint()
,又称jax.remat()
,不再支持concrete
选项,遵循前一个版本的弃用;请参阅JEP 11830。
- 变更
- 添加了
jax.pure_callback()
,允许从编译函数(例如用jax.jit
或jax.pmap
装饰的函数)调用纯 Python 函数。
- 弃用:
- 已移除不推荐使用的
DeviceArray.tile()
方法。使用jax.numpy.tile()
代替(#11944)。 - 已弃用
DeviceArray.to_py()
。请改用np.asarray(x)
。
jax 0.3.16
- GitHub 提交记录。
- 破坏性变更
- 支持 NumPy 1.19 已被移除,根据弃用政策。请升级到 NumPy 1.20 或更新版本。
- 变更
- 添加了
jax.debug
,包括用于运行时值调试的实用程序,如jax.debug.print()
和jax.debug.breakpoint()
。 - 添加了用于运行时值调试的新文档
- 弃用
- 移除了
jax.mask()
和jax.shapecheck()
API。详见#11557。 - 移除了
jax.experimental.loops
。可查看#10278获取替代 API。 jax.tree_util.tree_multimap()
已移除。自 JAX 版本 0.3.5 起已被弃用,jax.tree_util.tree_map()
是直接替换。- 删除了
jax.experimental.stax
;它长期以来一直是jax.example_libraries.stax
的弃用别名。 - 移除了
jax.experimental.optimizers
;它长期以来一直是jax.example_libraries.optimizers
的弃用别名。 jax.checkpoint()
,又称jax.remat()
,有了一个新的默认实现,意味着旧的实现已被弃用;请参阅JEP 11830。
JAX 中文文档(十六)(4)https://developer.aliyun.com/article/1559730