JAX 中文文档(十六)(3)

简介: JAX 中文文档(十六)

JAX 中文文档(十六)(2)https://developer.aliyun.com/article/1559727

jaxlib 0.4.14(2023 年 7 月 27 日)

  • 弃用

jax 0.4.13(2023 年 6 月 22 日)

  • 更改
  • jax.jit 现在允许将 None 传递给 in_shardingsout_shardings。语义如下:
  • 对于 in_shardings,JAX 将其标记为复制,但这种行为可能会在将来更改。
  • 对于 out_shardings,我们将依赖于 XLA GSPMD 分区器来确定输出的分片。
  • jax.experimental.pjit.pjit 也允许将 None 传递给 in_shardingsout_shardings。语义如下:
  • 如果未提供网格上下文管理器,则 JAX 可自由选择所需的分片方式。
  • 对于 in_shardings,JAX 将其标记为复制,但这种行为可能会在将来更改。
  • 对于 out_shardings,我们将依赖于 XLA GSPMD 分区器来确定输出的分片。
  • 如果提供了网格上下文管理器,None 将意味着该值将在网格的所有设备上复制。
  • Executable.cost_analysis() 在 Cloud TPU 上可用
  • 如果正在使用非允许的 jaxlib 插件,则添加了警告。
  • 添加了 jax.tree_util.tree_leaves_with_path
  • None 不是 jax.experimental.multihost_utils.host_local_array_to_global_arrayjax.experimental.multihost_utils.global_array_to_host_local_array 的有效输入。如果您希望复制您的输入,请使用 jax.sharding.PartitionSpec()
  • Bug 修复
  • 在 CUDA 12 发布中修复了错误的轮子名称(#16362);正确的轮子名称为 cudnn89 而不是 cudnn88
  • 弃用
  • jax.experimental.jax2tf.convert()native_serialization_strict_checks 参数已被弃用,推荐使用新的 native_serializaation_disabled_checks#16347)。

jaxlib 0.4.13(2023 年 6 月 22 日)

  • 更改
  • 将 Windows 仅 CPU 轮子添加到 jaxlib Pypi 发布中。
  • Bug 修复
  • __cuda_array_interface__ 在之前的 jaxlib 版本中出现问题,现已修复(#16440)。
  • 并行 CUDA 内核跟踪现在默认启用于 NVIDIA GPU。

jax 0.4.12(2023 年 6 月 8 日)

  • 更改
  • 弃用
  • jax.abstract_arrays及其内容已被弃用。请参阅:mod:jax.core中的相关功能。
  • jax.numpy.alltrue:使用jax.numpy.all。这遵循了 NumPy 版本 1.25.0 中numpy.alltrue的弃用。
  • jax.numpy.sometrue:使用jax.numpy.any。这遵循了 NumPy 版本 1.25.0 中numpy.sometrue的弃用。
  • jax.numpy.product:使用jax.numpy.prod。这遵循了 NumPy 版本 1.25.0 中numpy.product的弃用。
  • jax.numpy.cumproduct:使用jax.numpy.cumprod。这遵循了 NumPy 版本 1.25.0 中numpy.cumproduct的弃用。
  • jax.sharding.OpShardingSharding已被移除,因为它已经弃用了 3 个月。

jaxlib 0.4.12 (2023 年 6 月 8 日)

  • 变更
  • 包含了 Hopper(SM 版本 9.0+)GPU 的 PTX/SASS。之前的 jaxlib 版本应该可以在 Hopper 上工作,但第一次执行 JAX 操作时可能会有较长的 JIT 编译延迟。
  • Bug 修复
  • 修复了在 Python 3.11 下 JAX 生成的 Python 回溯中源代码行信息不正确的问题。
  • 修复了在 JAX 生成的 Python 回溯的帧中打印本地变量时崩溃的问题(#16027)。

jax 0.4.11 (2023 年 5 月 31 日)

  • 弃用
  • 根据 API 兼容性政策,在 3 个月的弃用期后,已移除以下 API:
  • jax.experimental.PartitionSpec:使用jax.sharding.PartitionSpec
  • jax.experimental.maps.Mesh:使用jax.sharding.Mesh
  • jax.experimental.pjit.NamedSharding:使用jax.sharding.NamedSharding
  • jax.experimental.pjit.PartitionSpec:使用jax.sharding.PartitionSpec
  • jax.experimental.pjit.FROM_GDA。请将分片的jax.Array对象作为输入传递,并删除pjit的可选in_shardings参数。
  • jax.interpreters.pxla.PartitionSpec:使用jax.sharding.PartitionSpec
  • jax.interpreters.pxla.Mesh:使用jax.sharding.Mesh
  • jax.interpreters.xla.Buffer:使用jax.Array
  • jax.interpreters.xla.Device:使用jax.Device
  • jax.interpreters.xla.DeviceArray:使用jax.Array
  • jax.interpreters.xla.device_put:使用jax.device_put
  • jax.interpreters.xla.xla_call_p:使用jax.experimental.pjit.pjit_p
  • with_sharding_constraintaxis_resources参数已被移除。请改用shardings

jaxlib 0.4.11 (2023 年 5 月 31 日)

  • 变更
  • Device添加了memory_stats()方法。如果支持,它将返回一个包含字符串统计名称和整数值的字典,例如"bytes_in_use",如果平台不支持内存统计,则返回 None。具体的统计数据可能因平台而异。目前仅在 Cloud TPU 上实现。
  • 重新添加了对 CPU 设备上 Python 缓冲协议(memoryview)的支持。

jax 0.4.10 (2023 年 5 月 11 日)

jaxlib 0.4.10 (2023 年 5 月 11 日)

  • 变更
  • 修复了阻止上一个版本在 Mac M1 上运行的'apple-m1' is not a recognized processor for this target (ignoring processor)问题。

jax 0.4.9 (2023 年 5 月 9 日)

  • 变更
  • experimental_cpp_jitexperimental_cpp_pjitexperimental_cpp_pmap标志已被移除。它们现在始终开启。
  • TPU 上奇异值分解(SVD)的准确性已经提高(需要 jaxlib 0.4.9)。
  • 废弃功能
  • jax.experimental.gda_serialization已废弃,并已重命名为jax.experimental.array_serialization。请更改您的导入以使用jax.experimental.array_serialization
  • pjitin_axis_resourcesout_axis_resources参数已废弃。请分别使用in_shardingsout_shardings
  • 函数jax.numpy.msort已被移除。自 JAX v0.4.1 起已被废弃。请使用jnp.sort(a, axis=0)代替。
  • in_partsout_parts参数已从jax.xla_computation中移除,因为它们只与sharded_jit一起使用,并且sharded_jit已不再使用。
  • 自从很久以来未被使用,instantiate_const_outputs参数已从jax.xla_computation中移除。

jaxlib 0.4.9(2023 年 5 月 9 日)

jax 0.4.8(2023 年 3 月 29 日)

  • 破坏性变更
  • Cloud TPU 运行时的一个重要组件已升级。这使得以下新功能在 Cloud TPU 上可用:
  • jax.debug.print()jax.debug.callback()jax.debug.breakpoint()现在在 Cloud TPU 上可用。
  • 自动 TPU 内存碎片整理
  • 在新的运行时组件上,不再支持jax.experimental.host_callback()在 Cloud TPU 上的使用。如果新的jax.debug API 不能满足您的需求,请在JAX 问题跟踪器上提出问题。
    旧的运行时组件将通过设置环境变量JAX_USE_PJRT_C_API_ON_TPU=false至少在接下来的三个月内可用。如果您发现需要出于任何原因禁用新的运行时,请在JAX 问题跟踪器上告知我们。
  • 变更
  • 最低 jaxlib 版本已从 0.4.6 提升至 0.4.7。
  • 废弃功能
  • 支持 CUDA 11.4 已被移除。JAX GPU 版本仅支持 CUDA 11.8 和 CUDA 12。如果使用旧版 CUDA 构建 jaxlib 可能会正常工作。
  • pmapglobal_arg_shapes参数仅适用于sharded_jit,已从pmap中移除。请迁移到pjit并从pmap中移除global_arg_shapes

jax 0.4.7(2023 年 3 月 27 日)

  • 变更
  • 根据 https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration,不再支持禁用jax.config.jax_array
  • 不再支持禁用jax.config.jax_jit_pjit_api_merge
  • jax.experimental.jax2tf.convert()现在支持native_serialization参数,使用 JAX 的本机降级到 StableHLO 以获取整个 JAX 函数的 StableHLO 模块,而不是将每个 JAX 原语降级到 TensorFlow 操作。这简化了内部操作,并增加了您序列化内容与 JAX 本机语义匹配的信心。详见文档。作为这一变更的一部分,配置标志--jax2tf_default_experimental_native_lowering已重命名为--jax2tf_native_serialization
  • JAX 现在依赖于 ml_dtypes,其中包含类似于 bfloat16 的 NumPy 类型的定义。这些定义以前是 JAX 的内部部分,但已拆分为一个单独的包,以便与其他项目共享。
  • JAX 现在要求使用 NumPy 1.21 或更新版本以及 SciPy 1.7 或更新版本。
  • 弃用信息
  • 类型 jax.numpy.DeviceArray 已弃用。请改用 jax.Array,它是其别名。
  • 类型 jax.interpreters.pxla.ShardedDeviceArray 已弃用。请改用 jax.Array
  • 通过位置传递额外参数给 jax.numpy.ndarray.at() 已被弃用。例如,不要使用 x.at[i].get(True),而是使用 x.at[i].get(indices_are_sorted=True)
  • jax.interpreters.xla.device_put 已被弃用。请使用 jax.device_put
  • jax.interpreters.pxla.device_put 已被弃用。请使用 jax.device_put
  • jax.experimental.pjit.FROM_GDA 已被弃用。请将分片的 jax.Arrays 作为输入,并移除 pjit 中的 in_shardings 参数,因为它是可选的。

jaxlib 0.4.7(2023 年 3 月 27 日)

变更:

  • jaxlib 现在依赖于 ml_dtypes,其中包含类似于 bfloat16 的 NumPy 类型的定义。这些定义以前是 JAX 的内部部分,但已拆分为一个单独的包,以便与其他项目共享。

jax 0.4.6(2023 年 3 月 9 日)

  • 变更
  • jax.tree_util 现在包含一组允许用户为其自定义 pytree 节点定义键的 API。
  • tree_flatten_with_path 可以展平树并返回每个叶子及其键路径。
  • tree_map_with_path 可以映射一个接受键路径作为参数的函数。
  • register_pytree_with_keys 用于注册自定义 pytree 节点中键路径和叶子的外观。
  • keystr 用于漂亮地打印键路径。
  • jax2tf.call_tf() 现在有一个新参数 output_shape_dtype(默认为 None),可用于声明结果的输出形状和类型。这使得 jax2tf.call_tf() 能够在形状多态性存在的情况下工作。(#14734)
  • 弃用信息
  • jax.tree_util 中的旧键路径 API 已被弃用,并将在 2023 年 3 月 10 日后的 3 个月内移除:
  • register_keypaths:请使用 jax.tree_util.register_pytree_with_keys() 替代。
  • AttributeKeyPathEntry:请改用 GetAttrKey
  • GetitemKeyPathEntry:请改用 SequenceKeyDictKey

jaxlib 0.4.6(2023 年 3 月 9 日)

jax 0.4.5(2023 年 3 月 2 日)

  • 弃用信息
  • jax.sharding.OpShardingSharding 已重命名为 jax.sharding.GSPMDShardingjax.sharding.OpShardingSharding 将在 2023 年 2 月 17 日后的 3 个月内移除。
  • 下列 jax.Array 方法已被弃用,并将在 2023 年 2 月 23 日后的 3 个月内移除:
  • jax.Array.broadcast:请使用 jax.lax.broadcast() 替代。
  • jax.Array.broadcast_in_dim:请使用 jax.lax.broadcast_in_dim() 替代。
  • jax.Array.split:请使用 jax.numpy.split() 替代。

jax 0.4.4(2023 年 2 月 16 日)

  • 变更
  • jitpjit 的实现已合并。合并 jit 和 pjit 改变了 JAX 的内部实现,但不影响 JAX 的公共 API。之前,jit 是一种最终风格的原语。最终风格意味着尽可能延迟创建 jaxpr 并将变换堆叠在一起。随着 jit-pjit 实现的合并,jit 变成了一种初始风格的原语,这意味着我们尽早追踪到 jaxpr。更多信息请参见 autodidax 中的这一部分。转移到初始风格应该简化 JAX 的内部实现,并使得动态形状等功能的开发更加容易。你只能通过环境变量来禁用它,即 os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'。由于它影响到 JAX 的导入时机,因此必须通过环境变量禁用它,在导入 jax 之前就需要禁用它。
  • with_sharding_constraintaxis_resources 参数已弃用。请改用 shardings。如果你将其作为参数使用,则无需更改。如果你将其作为关键字参数使用,请改用 shardingsaxis_resources 将在 2023 年 2 月 13 日后的 3 个月内删除。
  • 添加了 jax.typing 模块,用于 JAX 函数的类型注解工具。
  • 下列名称已被弃用:
  • jax.xla.Devicejax.interpreters.xla.Device: 使用 jax.Device
  • jax.experimental.maps.Mesh. 使用 jax.sharding.Mesh 替代。
  • jax.experimental.pjit.NamedSharding: 使用 jax.sharding.NamedSharding
  • jax.experimental.pjit.PartitionSpec: 使用 jax.sharding.PartitionSpec
  • jax.interpreters.pxla.Mesh: 使用 jax.sharding.Mesh
  • jax.interpreters.pxla.PartitionSpec: 使用 jax.sharding.PartitionSpec
  • Breaking Changes
  • jax.numpy.sum 等的 initial 参数现在要求是一个标量,与对应的 NumPy API 保持一致。以前的行为是对非标量 initial 值进行广播,这是一个意外的实现细节(#14446)。

jaxlib 0.4.4(2023 年 2 月 16 日)

  • Breaking changes
  • 默认的 jaxlib 构建中已移除对 NVIDIA Kepler 系列 GPU 的支持。如果需要 Kepler 支持,可以通过使用 Kepler 支持的源码构建 jaxlib(通过 build.py--cuda_compute_capabilities=sm_35 选项),不过请注意 CUDA 12 已完全停止支持 Kepler GPU。

jax 0.4.3(2023 年 2 月 8 日)

  • Breaking changes
  • 删除了 jax.scipy.linalg.polar_unitary(),这是一个已弃用的 JAX 扩展到 scipy API 的函数。请改用 jax.scipy.linalg.polar()
  • Changes
  • 添加了 jax.scipy.stats.rankdata()

jaxlib 0.4.3(2023 年 2 月 8 日)

  • jax.Array 现在具有非阻塞的 is_ready() 方法,如果数组已准备就绪则返回 True(参见 jax.block_until_ready())。

jax 0.4.2(2023 年 1 月 24 日)

  • Breaking changes
  • 删除了 jax.experimental.callback
  • 在存在 jax2tf 形状多态性的情况下,对带有维度的操作进行了泛化处理,通过将符号维度转换为 JAX 数组来在更多场景下工作。现在,涉及符号维度和 np.ndarray 的操作在结果用作形状值时可能会引发错误(#14106)。
  • 现在,jaxpr 对象在设置属性时会引发错误,以避免问题变异(#14102
  • 变更
  • jax2tf.call_tf() 现在有一个新参数 has_side_effects(默认为 True),可用于声明实例是否可以被 JAX 优化(如死代码消除)删除或复制(#13980)。
  • 为了支持 jax2tf 形状多态性的 floordivmod,我们增加了更多支持。之前,存在符号维度时某些除法操作会导致错误(#14108)。

jaxlib 0.4.2(2023 年 1 月 24 日)

  • 变更
  • 设置 JAX_USE_PJRT_C_API_ON_TPU=1 可启用新的 Cloud TPU 运行时,具备自动设备内存碎片整理功能。

jax 0.4.1(2022 年 12 月 13 日)

  • 变更
  • 根据 JAX 的 Python 和 NumPy 版本支持政策,不再支持 Python 3.7。
  • 我们引入了 jax.Array,它是 JAX 中的统一数组类型,涵盖了 DeviceArrayShardedDeviceArrayGlobalDeviceArray 类型。jax.Array 类型有助于使并行成为 JAX 的核心特性,简化和统一 JAX 内部结构,并允许我们统一 jitpjitjax.Array 已在 JAX 0.4 中默认启用,并对 pjit API 进行了一些破坏性更改。jax.Array 迁移指南 可帮助您将代码库迁移到 jax.Array。您还可以查看Distributed arrays and automatic parallelization 教程,以理解新概念。
  • PartitionSpecMesh 现在不再处于实验阶段。新的 API 端点是 jax.sharding.PartitionSpecjax.sharding.Meshjax.experimental.maps.Meshjax.experimental.PartitionSpec 已被弃用,并将在三个月内移除。
  • with_sharding_constraint 的新公共端点是 jax.lax.with_sharding_constraint
  • 如果与 jax.config 一起使用 ABSL 标志,那么在最初从 ABSL 标志填充 JAX 配置选项后,就不再读取或写入 ABSL 标志值。此更改改进了读取 jax.config 选项的性能,这些选项在 JAX 中广泛使用。
  • jax2tf.call_tf 函数现在使用与嵌入 JAX 计算相同平台的第一个 TF 设备进行 TF 降级。以前,它使用的是 JAX 默认后端的第 0 个设备。
  • 现在,一些 jax.numpy 函数的参数已标记为仅限位置参数,与 NumPy 匹配。
  • jnp.msort现已废弃,遵循 numpy 1.24 中np.msort的废弃。它将在未来的版本中移除,符合 API 兼容性策略。可以用jnp.sort(a, axis=0)替换。

jaxlib 0.4.1 (2022 年 12 月 13 日)

  • 变更
  • 支持 Python 3.7 已被放弃,符合 JAX 的 Python 和 NumPy 版本支持政策。
  • XLA_PYTHON_CLIENT_MEM_FRACTION=.XX的行为已更改,现在分配总 GPU 内存的 XX%来预分配,而不是以前使用当前可用 GPU 内存来计算预分配。有关更多详情,请参阅GPU memory allocation
  • 废弃的方法.block_host_until_ready()已被移除。请改用.block_until_ready()

jax 0.4.0 (2022 年 12 月 12 日)

  • 此版本已被撤回。

jaxlib 0.4.0 (2022 年 12 月 12 日)

  • 此版本已被撤回。

jax 0.3.25 (2022 年 11 月 15 日)

  • 变更
  • jax.numpy.linalg.pinv()现在支持hermitian选项。
  • jax.scipy.linalg.hessenberg()现在仅在 CPU 上支持。需要 jaxlib > 0.3.24。
  • 新函数jax.lax.linalg.hessenberg()jax.lax.linalg.tridiagonal()jax.lax.linalg.householder_product()已添加。Householder 约简目前仅支持 CPU,三对角约简支持 CPU 和 GPU。
  • 现在更经济地计算非方阵的svdjax.numpy.linalg.pinv的梯度。
  • 突破性变更
  • 删除了jax_experimental_name_stack配置选项。
  • 将字符串axis_names参数转换为jax.experimental.maps.Mesh构造函数的单例元组,而不是将字符串解包为字符轴名称序列。

jaxlib 0.3.25 (2022 年 11 月 15 日)

  • 变更
  • 添加了对 CPU 和 GPU 上三对角约简的支持。
  • 添加了对 CPU 上上 Hessenberg 约简的支持。
  • Bug 修复
  • 修复了一个 bug,导致 JAX 捕获的回溯中的帧被错误地映射到 Python 3.10+下的源行。

jax 0.3.24 (2022 年 11 月 4 日)

  • 变更
  • JAX 导入速度应更快。现在我们懒惰地导入 scipy,这在 JAX 的导入时间中占据了相当大的部分。
  • 设置环境变量JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N可以用于限制写入持久缓存的缓存条目数量。默认情况下,编译时间超过 1 秒的计算将被缓存。
  • 添加了 jax.scipy.stats.mode()
  • 如果在 TPU 上未指定顺序,则pmap的默认设备顺序现在与单进程作业的jax.devices()匹配。以前两种排序不同,可能导致不必要的复制或内存不足错误。要求排序一致简化了问题。
  • 突破性变更
  • jax.numpy.gradient()现在像jax.numpy中的大多数其他函数一样,禁止传递列表或元组以替代数组(#12958)。
  • jax.numpy.linalgjax.numpy.fft中的函数现在统一要求输入为数组形式:即不能使用列表和元组代替数组。部分属于#7737
  • 弃用
  • jax.sharding.MeshPspecSharding已重命名为jax.sharding.NamedShardingjax.sharding.MeshPspecSharding名称将在 3 个月内删除。

jaxlib 0.3.24(2022 年 11 月 4 日)

  • 更改
  • 现在在 CPU 上可以使用缓冲器捐赠。这可能会破坏在 CPU 上标记缓冲区进行捐赠但依赖捐赠未实现的代码。

jax 0.3.23(2022 年 10 月 12 日)

  • 更改
  • 更新 Colab TPU 驱动程序版本以支持新的 jaxlib 发布。

jax 0.3.22(2022 年 10 月 11 日)

  • 更改
  • 在 TPU 初始化中添加JAX_PLATFORMS=tpu,cpu作为默认设置,因此如果无法初始化 TPU,JAX 将引发错误而不是回退到 CPU。设置JAX_PLATFORMS=''以覆盖此行为并自动选择可用的后端(原始默认),或设置JAX_PLATFORMS=cpu以始终使用 CPU,而不管 TPU 是否可用。
  • 弃用
  • JAX v0.3.8 中弃用的几个测试工具现已从jax.test_util中移除。

jaxlib 0.3.22(2022 年 10 月 11 日)

jax 0.3.21(2022 年 9 月 30 日)

  • 持久化编译缓存现在在出错时会发出警告而不是抛出异常(#12582),所以如果缓存出现问题,程序可以继续执行。设置JAX_RAISE_PERSISTENT_CACHE_ERRORS=true可以恢复此行为。

jax 0.3.20(2022 年 9 月 28 日)

  • Bug 修复:
  • 添加了上一个发布版本中缺失的.pyi文件(#12536)。
  • 修复了jax 0.3.19 与其固定的 libtpu 版本之间的不兼容性(#12550)。需要 jaxlib 0.3.20。
  • 修复了setup.py注释中pip的错误网址(#12528)。

jaxlib 0.3.20(2022 年 9 月 28 日)

  • 修复通过jax_cuda_visible_devices在分布式作业中限制可见 CUDA 设备的支持。此功能对于 GPU 上的 JAX/SLURM 集成非常重要(#12533)。

jax 0.3.19(2022 年 9 月 27 日)

jax 0.3.18(2022 年 9 月 26 日)

  • 提前编译和编译功能(在#7733中跟踪)是稳定和公开的。查看概述jax.stages的 API 文档。
  • 引入了jax.Array,用于 JAX 中数组类型的isinstance检查和类型注释。请注意,这包括了对jax.numpy.ndarray在 JAX 内部对象中如何工作的一些微妙更改,因为jax.numpy.ndarray现在是jax.Array的简单别名。
  • 破坏性变更
  • jax._src不再导入公共jax命名空间。这可能会打破使用 JAX 内部功能的用户。
  • 已删除jax.soft_pmap。请改用pjitxmapjax.soft_pmap未记录文档。如果有文档记录,将提供弃用期。

jax 0.3.17(2022 年 8 月 31 日)

  • 修复了lax.pow的梯度在指数为零时的特殊情况问题(#12041
  • 破坏性变更
  • jax.checkpoint(),又称jax.remat(),不再支持concrete选项,遵循前一个版本的弃用;请参阅JEP 11830
  • 变更
  • 添加了jax.pure_callback(),允许从编译函数(例如用jax.jitjax.pmap装饰的函数)调用纯 Python 函数。
  • 弃用:
  • 已移除不推荐使用的DeviceArray.tile()方法。使用jax.numpy.tile()代替(#11944)。
  • 已弃用DeviceArray.to_py()。请改用np.asarray(x)

jax 0.3.16

  • 支持 NumPy 1.19 已被移除,根据弃用政策。请升级到 NumPy 1.20 或更新版本。
  • 变更
  • 添加了jax.debug,包括用于运行时值调试的实用程序,如jax.debug.print()jax.debug.breakpoint()
  • 添加了用于运行时值调试的新文档
  • 弃用
  • 移除了jax.mask()jax.shapecheck() API。详见#11557
  • 移除了jax.experimental.loops。可查看#10278获取替代 API。
  • jax.tree_util.tree_multimap()已移除。自 JAX 版本 0.3.5 起已被弃用,jax.tree_util.tree_map()是直接替换。
  • 删除了jax.experimental.stax;它长期以来一直是jax.example_libraries.stax的弃用别名。
  • 移除了jax.experimental.optimizers;它长期以来一直是jax.example_libraries.optimizers的弃用别名。
  • jax.checkpoint(),又称jax.remat(),有了一个新的默认实现,意味着旧的实现已被弃用;请参阅JEP 11830


JAX 中文文档(十六)(4)https://developer.aliyun.com/article/1559730

相关实践学习
基于阿里云DeepGPU实例,用AI画唯美国风少女
本实验基于阿里云DeepGPU实例,使用aiacctorch加速stable-diffusion-webui,用AI画唯美国风少女,可提升性能至高至原性能的2.6倍。
相关文章
|
2天前
|
并行计算 异构计算 索引
JAX 中文文档(十六)(4)
JAX 中文文档(十六)
12 2
|
2天前
|
并行计算 算法框架/工具 异构计算
JAX 中文文档(十六)(5)
JAX 中文文档(十六)
13 2
|
2天前
|
存储 API 索引
JAX 中文文档(十五)(5)
JAX 中文文档(十五)
14 3
|
2天前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(4)
JAX 中文文档(十五)
13 3
|
2天前
|
存储 缓存 API
JAX 中文文档(十六)(1)
JAX 中文文档(十六)
11 1
|
2天前
|
并行计算 API 异构计算
JAX 中文文档(十六)(2)
JAX 中文文档(十六)
10 1
|
2天前
|
机器学习/深度学习 数据可视化 编译器
JAX 中文文档(十四)(5)
JAX 中文文档(十四)
9 2
|
2天前
|
算法 API 开发工具
JAX 中文文档(十二)(5)
JAX 中文文档(十二)
7 1
|
2天前
|
TensorFlow API 算法框架/工具
JAX 中文文档(十五)(3)
JAX 中文文档(十五)
9 0
|
2天前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(2)
JAX 中文文档(十五)
13 0