JAX 中文文档(十六)(4)

简介: JAX 中文文档(十六)

JAX 中文文档(十六)(3)https://developer.aliyun.com/article/1559729


jax 0.3.15(2022 年 7 月 22 日)

  • jax.test_util 中已移除 JaxTestCaseJaxTestLoader 类,自 v0.3.1 起已弃用(#11248)。
  • 添加了 jax.scipy.gaussian_kde#11237)。
  • JAX 数组与内置集合(dictlistsettuple)之间的二元操作现在在所有情况下都会引发 TypeError。以前的某些情况(特别是相等性和不等式)会返回与 NumPy 中类似操作不一致的布尔标量(#11234)。
  • 几个作为顶级 JAX 包导入的 jax.tree_util 例程现已弃用,并将根据 API 兼容性政策在未来的 JAX 发布版本中移除。
  • jax.treedef_is_leaf() 已弃用,推荐使用 jax.tree_util.treedef_is_leaf()
  • jax.tree_flatten() 已弃用,推荐使用 jax.tree_util.tree_flatten()
  • jax.tree_leaves() 已弃用,推荐使用 jax.tree_util.tree_leaves()
  • jax.tree_structure() 已弃用,推荐使用 jax.tree_util.tree_structure()
  • jax.tree_transpose() 已弃用,推荐使用 jax.tree_util.tree_transpose()
  • jax.tree_unflatten() 已弃用,推荐使用 jax.tree_util.tree_unflatten()
  • jax.scipy.linalg.solve()sym_pos 参数已弃用,推荐使用 assume_a='pos',遵循 scipy.linalg.solve() 中类似的弃用。

jaxlib 0.3.15(2022 年 7 月 22 日)

jax 0.3.14(2022 年 6 月 27 日)

  • jax.experimental.compilation_cache.initialize_cache() 现在不再支持 max_cache_size_ bytes,并且不会将其作为输入。
  • 当平台初始化失败时,JAX_PLATFORMS 现在会引发异常。
  • 变更
  • 解决了与 NumPy 1.23 的兼容性问题。
  • jax.numpy.linalg.slogdet() 现在接受一个可选的 method 参数,允许选择基于 LU 分解或基于 QR 分解的实现。
  • jax.numpy.linalg.qr() 现在支持 mode="raw"
  • 在对 JAX 数组使用 picklecopy.copycopy.deepcopy 时,现在支持更完整的支持(#10659)。特别是:
  • 当对 DeviceArray 使用 pickledeepcopy 时,以前返回 np.ndarray 对象,现在返回 DeviceArray 对象。对于 deepcopy,复制的数组位于与原始数组相同的设备上。对于 pickle,反序列化的数组将位于默认设备上。
  • 在函数转换(即跟踪代码)内部,deepcopycopy 以前是空操作。现在它们使用与 DeviceArray.copy() 相同的机制。
  • 对跟踪数组进行 pickle 操作现在会导致显式的 ConcretizationTypeError
  • 在 TPU 上,奇异值分解(SVD)和对称/Hermitian 特征分解的实现应显著更快,特别是对于超过 1000x1000 大小的矩阵。现在都使用了谱分裂与征算法进行特征分解(QDWH-eig)。
  • jax.numpy.ldexp() 现在不再将所有输入默认提升为 float64,而是对于 int32 或更小的整数输入,提升为 float32 (#10921)。
  • 添加了一个 create_perfetto_link 选项到 jax.profiler.start_trace()jax.profiler.start_trace()。使用时,分析器将生成一个链接到 Perfetto UI 以查看跟踪信息。
  • 更改了 jax.profiler.start_server(...)() 的语义,将 keepalive 全局存储,而不再要求用户保留引用。
  • 添加了 jax.random.generalized_normal()
  • 添加了 jax.random.ball()
  • 添加了 jax.default_device()
  • 添加了一个 python -m jax.collect_profile 脚本,手动捕获程序跟踪,作为 TensorBoard UI 的替代方法。
  • 添加了一个 jax.named_scope 上下文管理器,向 Python 程序添加分析器元数据(类似于 jax.named_call)。
  • 在 scatter-update 操作(即 :attr:jax.numpy.ndarray.at)中,不安全的隐式 dtype 转换已弃用,现在会产生 FutureWarning。在将来的版本中,这将变成一个错误。一个不安全的隐式转换的例子是 jnp.zeros(4, dtype=int).at[0].set(1.5),其中 1.5 之前会被静默截断为 1
  • jax.experimental.compilation_cache.initialize_cache() 现在支持 gcs 存储桶路径作为输入。
  • 添加了 jax.scipy.stats.gennorm()
  • jax.numpy.roots() 现在在 strip_zeros=False 时,在系数有前导零时行为更佳 (#11215)。

jaxlib 0.3.14(2022 年 6 月 27 日)。

  • x86-64 Mac wheels 现在要求 Mac OS 10.14(Mojave)或更新版本。Mac OS 10.14 发布于 2018 年,因此这不应该是一个非常繁重的要求。
  • 捆绑的 NCCL 版本更新到 2.12.12,修复了一些死锁问题。
  • Python flatbuffers 包不再是 jaxlib 的依赖项。

jax 0.3.13(2022 年 5 月 16 日)。

jax 0.3.12(2022 年 5 月 15 日)。

jax 0.3.11(2022 年 5 月 15 日)。

  • jax.lax.eigh() 现在接受一个可选的 sort_eigenvalues 参数,允许用户在 TPU 上选择不排序特征值。
  • 弃用:
  • jax.lax.linalg 中的函数现在要求非数组参数必须作为关键字参数传递。为了向后兼容,将关键字参数作为位置参数传递将会得到警告,但在未来的 JAX 发布中,将会导致失败。大多数用户应该优先考虑使用 jax.numpy.linalg
  • jax.scipy.linalg.polar_unitary(),这是 JAX 对 scipy API 的扩展,已被弃用。请改用 jax.scipy.linalg.polar()

jax 0.3.10 (2022 年 5 月 3 日)

jaxlib 0.3.10 (2022 年 5 月 3 日)

  • TF 提交记录 修复了 MHLO 规范化器中的问题,该问题导致某些程序的常量折叠花费很长时间或崩溃。

jax 0.3.9 (2022 年 5 月 2 日)

  • 增加了对 GlobalDeviceArray 的完全异步检查点支持。

jax 0.3.8 (2022 年 4 月 29 日)

  • 在 TPU 上,jax.numpy.linalg.svd() 现在使用 qdwh-svd 求解器。
  • 在 TPU 上,jax.numpy.linalg.cond() 现在接受复数输入。
  • 在 TPU 上,jax.numpy.linalg.pinv() 现在接受复数输入。
  • 在 TPU 上,jax.numpy.linalg.matrix_rank() 现在接受复数输入。
  • 已添加 jax.scipy.cluster.vq.vq()
  • jax.experimental.maps.mesh 已删除。请使用 jax.experimental.maps.Mesh。请参阅 此处 获取更多信息。
  • mode='r' 时,jax.scipy.linalg.qr() 现在返回一个长度为 1 的元组,而不是原始数组,以匹配 scipy.linalg.qr 的行为(#10452
  • jax.numpy.take_along_axis() 现在接受一个可选的 mode 参数,用于指定超出边界索引的行为。默认情况下,超出边界的索引会返回无效值(例如 NaN)。在 JAX 的早期版本中,无效的索引会被夹在范围内。可以通过传递 mode="clip" 恢复先前的行为。
  • jax.numpy.take() 现在默认为 mode="fill",这会对超出索引范围的位置返回无效值(例如 NaN)。
  • 散点操作,例如 x.at[...].set(...),现在具有 "drop" 语义。这对散点操作本身没有影响,但这意味着在进行微分时,散点的梯度对超出边界的索引的余切为零。以前超出边界的索引在梯度中被夹在范围内,这在数学上是不正确的。
  • jax.numpy.take_along_axis() 现在如果其索引不是整数类型将会引发 TypeError,与 numpy.take_along_axis() 的行为一致。先前非整数索引会被静默转换为整数。
  • jax.numpy.ravel_multi_index() 现在如果其 dims 参数不是整数类型将会引发 TypeError,与 numpy.ravel_multi_index() 的行为一致。先前非整数 dims 参数会被静默转换为整数。
  • jax.numpy.split() 现在如果其 axis 参数不是整数类型将会引发 TypeError,与 numpy.split() 的行为一致。先前非整数 axis 参数会被静默转换为整数。
  • jax.numpy.indices() 现在如果其维度不是整数类型将会引发 TypeError,与 numpy.indices() 的行为一致。先前非整数维度会被静默转换为整数。
  • jax.numpy.diag() 现在如果其 k 参数不是整数类型将会引发 TypeError,与 numpy.diag() 的行为一致。先前非整数 k 参数会被静默转换为整数。
  • 添加了 jax.random.orthogonal()
  • 已过时:
  • 许多 jax.test_util 中可用的函数和对象现已过时,并将在导入时引发警告。包括 cases_from_listcheck_closecheck_eqdevice_under_testformat_shape_dtype_stringrand_uniformskip_on_deviceswith_configxla_bridge_default_tolerance#10389)。这些以及先前过时的 JaxTestCaseJaxTestLoaderBufferDonationTestCase 将在未来的 JAX 发布中移除。大多数这些实用程序可以通过调用标准的 Python 和 NumPy 测试实用程序来替换,如 unittestabsl.testingnumpy.testing 等。可以通过公共 API(例如 jax.devices())来替换 JAX 特定的功能,如设备检查。许多已过时的实用程序仍然存在于 jax._src.test_util 中,但这些不是公共 API,因此可能在未来的发布中更改或移除,而不另行通知。

jax 0.3.7(2022 年 4 月 15 日)

  • 修复了当传递给 jax.numpy.take_along_axis() 的索引广播时的性能问题(#10281)。
  • jax.scipy.special.expit()jax.scipy.special.logit() 现在要求其参数为标量或 JAX 数组。它们现在还将整数参数提升为浮点数。
  • DeviceArray.tile() 方法已弃用,因为 numpy 数组没有 tile() 方法。作为替代,请使用 jax.numpy.tile()#10266)。

jaxlib 0.3.7(2022 年 4 月 15 日)

  • 变更:
  • Linux 版本现在符合 manylinux2014 标准,而不是 manylinux2010

jax 0.3.6(2022 年 4 月 12 日)

  • 将 libtpu 轮子升级到修复初始化 TPU pod 时挂起的版本。修复了 #10218
  • 弃用:
  • jax.experimental.loops 将被弃用。参见 #10278 了解替代 API。

jax 0.3.5(2022 年 4 月 7 日)

  • 添加了 jax.random.loggamma() 并改进了对小参数值的 jax.random.beta()jax.random.dirichlet() 的行为(#9906)。
  • lax_numpy 私有子模块不再暴露在 jax.numpy 命名空间中(#10029)。
  • 添加了数组创建例程 jax.numpy.frombuffer()jax.numpy.fromfunction()jax.numpy.fromstring()#10049)。
  • DeviceArray.copy() 现在返回 DeviceArray 而不是 np.ndarray#10069
  • 添加了 jax.scipy.linalg.rsf2csf()
  • jax.experimental.sharded_jit 已被弃用,并将很快移除。
  • 弃用:
  • jax.nn.normalize() 将被弃用。请使用 jax.nn.standardize() 替代(#9899)。
  • jax.tree_util.tree_multimap() 已弃用。请使用 jax.tree_util.tree_map() 替代(#5746)。
  • jax.experimental.sharded_jit 已弃用。请使用 pjit 替代。

jaxlib 0.3.5(2022 年 4 月 7 日)

  • 修复了 bug
  • 修复了一个 bug,双精度复杂到实数 IRFFT 在 GPU 上会改变其输入缓冲区(#9946)。
  • 修复了复杂散布常量折叠错误(#10159

jax 0.3.4(2022 年 3 月 18 日)

jax 0.3.3(2022 年 3 月 17 日)

jax 0.3.2(2022 年 3 月 16 日)

  • 函数 jax.ops.index_updatejax.ops.index_add 在 0.2.22 中已弃用。请使用JAX 数组上的 .at 属性,例如,x.at[idx].set(y)
  • jax.experimental.ann.approx_*_k 移至 jax.lax。这些函数是 jax.lax.top_k 的优化替代品。
  • jax.numpy.broadcast_arrays()jax.numpy.broadcast_to() 现在要求标量或类数组输入,并在传递列表时将失败(部分 #7737)。
  • 标准的 jax[tpu] 安装现在可以与 Cloud TPU v4 VMs 一起使用。
  • pjit 现在支持在 CPU 上运行(除了之前的 TPU 和 GPU 支持)。

jaxlib 0.3.2 (2022 年 3 月 16 日)

  • 更改
  • XlaComputation.as_hlo_text() 现在支持通过传递布尔标志 print_large_constants=True 打印大常量。
  • 弃用:
  • JAX 数组上的 .block_host_until_ready() 方法已弃用。请改用 .block_until_ready()

jax 0.3.1 (2022 年 2 月 18 日)

  • jax.test_util.JaxTestCasejax.test_util.JaxTestLoader 现在已弃用。建议直接使用 parametrized.TestCase 进行替换。对于依赖于自定义断言(如 JaxTestCase.assertAllClose())的测试,请使用标准的 numpy 测试工具,如numpy.testing.assert_allclose(),它们直接与 JAX 数组一起工作(#9620)。
  • jax.test_util.JaxTestCase 现在默认设置 jax_numpy_rank_promotion='raise'#9562)。要恢复以前的行为,请使用新的 jax.test_util.with_config 装饰器:
@jtu.with_config(jax_numpy_rank_promotion='allow')
class MyTestCase(jtu.JaxTestCase):
  ... 
  • 添加了 jax.scipy.linalg.schur()jax.scipy.linalg.sqrtm()jax.scipy.signal.csd()jax.scipy.signal.stft()jax.scipy.signal.welch()

jax 0.3.0 (2022 年 2 月 10 日)

  • jax 版本已升级至 0.3.0. 请参阅设计文档以获取说明。

jaxlib 0.3.0 (2022 年 2 月 10 日)

  • 更改
  • 现在需要 Bazel 5.0.0 来构建 jaxlib。
  • jaxlib 版本已升级至 0.3.0. 请参阅设计文档以获取说明。

jax 0.2.28 (2022 年 2 月 1 日)

  • 如果未传递 dialect=jax.jit(f).lower(...).compiler_ir() 现在默认为 MHLO 方言。
  • jax.jit(f).lower(...).compiler_ir(dialect='mhlo') 现在返回 MLIR ir.Module 对象,而不是其字符串表示。

jaxlib 0.1.76 (2022 年 1 月 27 日)

  • 新功能
  • 包括为 NVidia 计算能力 8.0 的 GPU(例如 A100)预编译的 SASS。删除了计算能力 6.1 的预编译 SASS,以避免增加计算能力的数量:具有计算能力 6.1 的 GPU 可以使用 6.0 的 SASS。
  • 使用 jaxlib 0.1.76,JAX 默认使用 MHLO MLIR 方言作为其主要目标编译器 IR。
  • Breaking changes
  • 不再支持 NumPy 1.18,根据弃用策略。请升级到支持的 NumPy 版本。
  • Bug 修复
  • 修复了一个 bug,即由不同路径构造的表面相同的 pytreedef 对象不会被视为相等(#9066)。
  • JAX jit 缓存要求两个静态参数具有相同的类型以进行缓存命中(#9311)。

jax 0.2.27(2022 年 1 月 18 日)

  • 不再支持 NumPy 1.18,根据弃用策略。请升级到支持的 NumPy 版本。
  • host_callback 原语已简化,取消了 hcb.id_tap 和 id_print 的特殊自动微分处理。从现在开始,只有原始值被 tap。可以通过设置 JAX_HOST_CALLBACK_AD_TRANSFORMS 环境变量或 --jax_host_callback_ad_transforms 标志来获取旧的行为(在有限时间内)。此外,增加了如何使用 JAX 自定义 AD API 实现旧行为的文档(#8678)。
  • 排序现在与 NumPy 的行为匹配,无论位表示如何,对于 0.0NaN 都是如此。特别是,现在 0.0-0.0 被视为等价,而之前 -0.0 被视为小于 0.0。此外,所有的 NaN 表示现在都被视为等价,并且按照这些位模式排序到数组的末尾。以前,负数的 NaN 值被排序到数组的前面,并且具有不同内部位表示的 NaN 值不被视为等价,根据这些位模式排序(#9178)。
  • jax.numpy.unique() 现在在处理 NaN 值时与 NumPy 版本 1.21 及更新版本的 np.unique 一样:在唯一化的输出中最多只会出现一个 NaN 值(#9184)。
  • Bug 修复:
  • 现在 host_callback 支持 ad_checkpoint.checkpoint(#8907)。
  • 新功能:
  • 添加了 jax.block_until_ready({jax-issue}`#8941)。
  • 添加了一个新的调试标志/环境变量 JAX_DUMP_IR_TO=/path。如果设置了,JAX 会将它为每个计算生成的 MHLO/HLO IR 转储到给定路径下的文件。
  • 添加了 jax.ensure_compile_time_eval 到公共 API(#7987)。
  • jax2tf 现在支持一个标志 jax2tf_associative_scan_reductions,用于改变关联约简的降低,例如 jnp.cumsum,在 CPU 和 GPU 上的行为(使用关联扫描)。更多细节请参见 jax2tf README(#9189)。

jaxlib 0.1.75(2021 年 12 月 8 日)

  • 新功能:
  • 支持 python 3.10。

jax 0.2.26(2021 年 12 月 8 日)

  • jax.ops.segment_sum 的越界索引现在将使用 FILL_OR_DROP 语义处理,如文档中所述。这主要影响反向模式导数,其中与越界索引对应的梯度现在将返回为 0。(#8634)。
  • jax2tf 现在会强制转换代码,使其在 jax.jit 下的代码片段使用 XLA,例如大多数 jax.numpy 函数(#7839)。

jaxlib 0.1.74(2021 年 11 月 17 日)

  • 在 GPU 之间启用点对点复制。以前,GPU 复制通过主机反弹,这通常更慢。
  • 增加了实验性的 MLIR Python 绑定,供 JAX 使用。

jax 0.2.25(2021 年 11 月 10 日)

  • (实验性)jax.distributed.initialize 暴露多主机 GPU 后端。
  • jax.random.permutation 支持新的 independent 关键字参数(#8430
  • 破坏性更改
  • jax.experimental.stax 移至 jax.example_libraries.stax
  • jax.experimental.optimizers 移至 jax.example_libraries.optimizers
  • 新功能:
  • 添加了 jax.lax.linalg.qdwh

jax 0.2.24(2021 年 10 月 19 日)

  • jax.random.choicejax.random.permutation 现在支持多维数组和可选的 axis 参数(#8158)。
  • 破坏性更改:
  • 现在 jax.numpy.takejax.numpy.take_along_axis 要求数组样式的输入(参见 #7737)。

jaxlib 0.1.73(2021 年 10 月 18 日)

  • 现在支持多个 cuDNN 版本的 jaxlib GPU cuda11 轮。
  • cuDNN 8.2 或更新版本。如果您的 cuDNN 安装足够新,请使用 cuDNN 8.2 轮,因为它支持额外的功能。
  • cuDNN 8.0.5 或更新版本。
  • 破坏性更改:
  • GPU jaxlib 的安装命令如下:
pip  install  --upgrade  pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip  install  --upgrade  "jax[cuda]"  -f  https://storage.googleapis.com/jax-releases/jax_releases.html
# Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
pip  install  jax[cuda11_cudnn82]  -f  https://storage.googleapis.com/jax-releases/jax_releases.html
# Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
pip  install  jax[cuda11_cudnn805]  -f  https://storage.googleapis.com/jax-releases/jax_releases.html 

jax 0.2.22(2021 年 10 月 12 日)

  • jax.pmap 的静态参数现在必须是可哈希的。
    jax.jit 上长期不允许非哈希静态参数,但在 jax.pmap 上仍然允许;jax.pmap 使用对象标识比较非哈希静态参数。
    这种行为可能会导致一些问题,因为使用对象身份比较来比较参数会导致每次对象身份变化时重新编译。现在我们禁止非可哈希参数:如果 jax.pmap 的用户希望通过对象身份比较静态参数,他们可以在其对象上定义 __hash____eq__ 方法,或者将其对象包装在具有对象身份语义的对象中。另一种选择是使用 functools.partial 将非可哈希的静态参数封装到函数对象中。
  • jax.util.partial 是一个意外导出的内容,已被移除。请使用 Python 标准库中的 functools.partial 替代。
  • Deprecations
  • 函数 jax.ops.index_updatejax.ops.index_add 等已被弃用,并将在未来的 JAX 版本中移除。请改用 JAX 数组上的 .at 属性,例如 x.at[idx].set(y)。目前,这些函数会产生 DeprecationWarning
  • New features:
  • 优化的 C++ 代码路径现在是使用 jaxlib 0.1.72 或更新版本时的默认设置,用于提高 pmap 的调度时间。可以使用 --experimental_cpp_pmap 标志(或 JAX_CPP_PMAP 环境变量)禁用该功能。
  • jax.numpy.unique 现在支持一个可选的 fill_value 参数(#8121)。

jaxlib 0.1.72 (Oct 12, 2021)

  • Breaking changes:
  • CUDA 10.2 和 CUDA 10.1 的支持已被移除。Jaxlib 现在支持 CUDA 11.1+。
  • Bug fixes:

jax 0.2.21 (Sept 23, 2021)

  • jax.api 已被移除。之前作为 jax.api.* 可用的函数现在被别名为 jax.* 中的函数;请直接使用 jax.* 中的函数。
  • jax.partialjax.lax.partial 是意外导出的内容,已被移除。请使用 Python 标准库中的 functools.partial 替代。
  • 布尔标量索引现在会引发 TypeError;之前这些操作会静默返回错误的结果(#7925)。
  • 许多 jax.numpy 函数现在要求数组样式的输入,如果传递列表将会报错(#7747 #7802 #7907)。查看 #7737 以了解此更改背后的原因讨论。
  • 当在 jax.jit 等转换内部时,jax.numpy.array 总是将其生成的数组分阶段到跟踪的计算中。以前的 jax.numpy.array 有时会在 jax.jit 装饰器下生成一个设备上的数组。这种变化可能会破坏使用 JAX 数组执行必须静态知道形状或索引计算的代码;解决方法是改用经典的 NumPy 数组执行这些计算。
  • jnp.ndarray 现在是 JAX 数组的真正基类。特别地,对于标准的 numpy 数组 xisinstance(x, jnp.ndarray) 现在会返回 False (#7927)。
  • 新特性:
  • 添加了 jax.numpy.insert() 的实现 (#7936)。

jax 0.2.20 (2021 年 9 月 2 日)

  • jnp.poly* 函数现在要求数组样式的输入 (#7732)。
  • jnp.unique 和其他类似集合的操作现在要求数组样式的输入 (#7662)。

jaxlib 0.1.71 (2021 年 9 月 1 日)

  • Breaking changes:
  • 不再支持 CUDA 11.0 和 CUDA 10.1。Jaxlib 现在支持 CUDA 10.2 和 CUDA 11.1+。

jax 0.2.19 (2021 年 8 月 12 日)

  • 支持 NumPy 1.17 已经被废弃,按照弃用政策。请升级到支持的 NumPy 版本。
  • 在 JAX 数组的多个操作的实现周围添加了 jit 装饰器。这加快了常见操作如 + 的调度时间。
    这个变化对大多数用户基本上是透明的。但是,有一个已知的行为变化,即直接传递给 JAX 操作符的大整数常数现在可能会产生错误(例如 x + 2**40)。解决方法是将常数转换为显式类型(例如 np.float64(2**40))。
  • 新特性:
  • 改进了对需要在数组计算中使用维度大小的操作在 jax2tf 中的形状多态支持,例如 jnp.mean。 (#7317)。
  • Bug 修复:
  • 上一个版本的泄漏的追踪错误 (#7613)。

jaxlib 0.1.70 (2021 年 8 月 9 日)

  • Breaking changes:
  • 支持 Python 3.6 已经被废弃,按照弃用政策。请升级到支持的 Python 版本。
  • 支持 NumPy 1.17 已经被废弃,按照弃用政策。请升级到支持的 NumPy 版本。
  • 现在主机回调机制每个本地设备使用一个线程来调用 Python 回调。以前所有设备共用一个线程。这意味着现在回调可能交错调用。仍然会按顺序调用一个设备对应的所有回调。

jax 0.2.18(2021 年 7 月 21 日)

  • 根据弃用策略,不再支持 Python 3.6。请升级到支持的 Python 版本。
  • jaxlib 最低版本现在是 0.1.69。
  • jax.dlpack.from_dlpack()backend 参数已移除。
  • 新功能:
  • 添加了极分解(jax.scipy.linalg.polar())。
  • Bug 修复:
  • 加强了对 lax.argmin 和 lax.argmax 的检查,以确保它们不会使用无效的 axis 值或空的减少维度。 (#7196


JAX 中文文档(十六)(5)https://developer.aliyun.com/article/1559732

相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
3月前
|
并行计算 算法框架/工具 异构计算
JAX 中文文档(十六)(5)
JAX 中文文档(十六)
44 2
|
3月前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(4)
JAX 中文文档(十五)
26 3
|
3月前
|
存储 API 索引
JAX 中文文档(十五)(5)
JAX 中文文档(十五)
34 3
|
3月前
|
并行计算 API 异构计算
JAX 中文文档(十六)(2)
JAX 中文文档(十六)
77 1
|
3月前
|
存储 缓存 API
JAX 中文文档(十六)(1)
JAX 中文文档(十六)
30 1
|
3月前
|
机器学习/深度学习 数据可视化 编译器
JAX 中文文档(十四)(5)
JAX 中文文档(十四)
34 2
|
3月前
|
算法 API 开发工具
JAX 中文文档(十二)(5)
JAX 中文文档(十二)
40 1
|
3月前
|
并行计算 API 异构计算
JAX 中文文档(十六)(3)
JAX 中文文档(十六)
55 0
|
3月前
|
TensorFlow API 算法框架/工具
JAX 中文文档(十五)(3)
JAX 中文文档(十五)
22 0
|
3月前
|
安全 API 网络架构
JAX 中文文档(十五)(1)
JAX 中文文档(十五)
37 0