JAX 中文文档(十六)(3)https://developer.aliyun.com/article/1559729
jax 0.3.15(2022 年 7 月 22 日)
- GitHub 提交记录。
- 变更
jax.test_util中已移除JaxTestCase和JaxTestLoader类,自 v0.3.1 起已弃用(#11248)。- 添加了
jax.scipy.gaussian_kde(#11237)。 - JAX 数组与内置集合(
dict、list、set、tuple)之间的二元操作现在在所有情况下都会引发TypeError。以前的某些情况(特别是相等性和不等式)会返回与 NumPy 中类似操作不一致的布尔标量(#11234)。 - 几个作为顶级 JAX 包导入的
jax.tree_util例程现已弃用,并将根据 API 兼容性政策在未来的 JAX 发布版本中移除。
jax.treedef_is_leaf()已弃用,推荐使用jax.tree_util.treedef_is_leaf()。jax.tree_flatten()已弃用,推荐使用jax.tree_util.tree_flatten()。jax.tree_leaves()已弃用,推荐使用jax.tree_util.tree_leaves()。jax.tree_structure()已弃用,推荐使用jax.tree_util.tree_structure()。jax.tree_transpose()已弃用,推荐使用jax.tree_util.tree_transpose()。jax.tree_unflatten()已弃用,推荐使用jax.tree_util.tree_unflatten()。
jax.scipy.linalg.solve()的sym_pos参数已弃用,推荐使用assume_a='pos',遵循scipy.linalg.solve()中类似的弃用。
jaxlib 0.3.15(2022 年 7 月 22 日)
jax 0.3.14(2022 年 6 月 27 日)
- GitHub 提交。
- 破坏性变更
jax.experimental.compilation_cache.initialize_cache()现在不再支持max_cache_size_ bytes,并且不会将其作为输入。- 当平台初始化失败时,
JAX_PLATFORMS现在会引发异常。
- 变更
- 解决了与 NumPy 1.23 的兼容性问题。
jax.numpy.linalg.slogdet()现在接受一个可选的method参数,允许选择基于 LU 分解或基于 QR 分解的实现。jax.numpy.linalg.qr()现在支持mode="raw"。- 在对 JAX 数组使用
pickle、copy.copy和copy.deepcopy时,现在支持更完整的支持(#10659)。特别是:
- 当对
DeviceArray使用pickle和deepcopy时,以前返回np.ndarray对象,现在返回DeviceArray对象。对于deepcopy,复制的数组位于与原始数组相同的设备上。对于pickle,反序列化的数组将位于默认设备上。 - 在函数转换(即跟踪代码)内部,
deepcopy和copy以前是空操作。现在它们使用与DeviceArray.copy()相同的机制。 - 对跟踪数组进行
pickle操作现在会导致显式的ConcretizationTypeError。
- 在 TPU 上,奇异值分解(SVD)和对称/Hermitian 特征分解的实现应显著更快,特别是对于超过 1000x1000 大小的矩阵。现在都使用了谱分裂与征算法进行特征分解(QDWH-eig)。
jax.numpy.ldexp()现在不再将所有输入默认提升为 float64,而是对于 int32 或更小的整数输入,提升为 float32 (#10921)。- 添加了一个
create_perfetto_link选项到jax.profiler.start_trace()和jax.profiler.start_trace()。使用时,分析器将生成一个链接到 Perfetto UI 以查看跟踪信息。 - 更改了
jax.profiler.start_server(...)()的语义,将 keepalive 全局存储,而不再要求用户保留引用。 - 添加了
jax.random.generalized_normal()。 - 添加了
jax.random.ball()。 - 添加了
jax.default_device()。 - 添加了一个
python -m jax.collect_profile脚本,手动捕获程序跟踪,作为 TensorBoard UI 的替代方法。 - 添加了一个
jax.named_scope上下文管理器,向 Python 程序添加分析器元数据(类似于jax.named_call)。 - 在 scatter-update 操作(即 :attr:
jax.numpy.ndarray.at)中,不安全的隐式 dtype 转换已弃用,现在会产生FutureWarning。在将来的版本中,这将变成一个错误。一个不安全的隐式转换的例子是jnp.zeros(4, dtype=int).at[0].set(1.5),其中1.5之前会被静默截断为1。 jax.experimental.compilation_cache.initialize_cache()现在支持 gcs 存储桶路径作为输入。- 添加了
jax.scipy.stats.gennorm()。 jax.numpy.roots()现在在strip_zeros=False时,在系数有前导零时行为更佳 (#11215)。
jaxlib 0.3.14(2022 年 6 月 27 日)。
- x86-64 Mac wheels 现在要求 Mac OS 10.14(Mojave)或更新版本。Mac OS 10.14 发布于 2018 年,因此这不应该是一个非常繁重的要求。
- 捆绑的 NCCL 版本更新到 2.12.12,修复了一些死锁问题。
- Python flatbuffers 包不再是 jaxlib 的依赖项。
jax 0.3.13(2022 年 5 月 16 日)。
jax 0.3.12(2022 年 5 月 15 日)。
- GitHub 提交记录。
- 变更:
- 修复了 #10717。
jax 0.3.11(2022 年 5 月 15 日)。
- GitHub 提交记录。
- 变更:
jax.lax.eigh()现在接受一个可选的sort_eigenvalues参数,允许用户在 TPU 上选择不排序特征值。
- 弃用:
jax.lax.linalg中的函数现在要求非数组参数必须作为关键字参数传递。为了向后兼容,将关键字参数作为位置参数传递将会得到警告,但在未来的 JAX 发布中,将会导致失败。大多数用户应该优先考虑使用jax.numpy.linalg。jax.scipy.linalg.polar_unitary(),这是 JAX 对 scipy API 的扩展,已被弃用。请改用jax.scipy.linalg.polar()。
jax 0.3.10 (2022 年 5 月 3 日)
jaxlib 0.3.10 (2022 年 5 月 3 日)
- GitHub 提交记录.
- 变更
- TF 提交记录 修复了 MHLO 规范化器中的问题,该问题导致某些程序的常量折叠花费很长时间或崩溃。
jax 0.3.9 (2022 年 5 月 2 日)
- GitHub 提交记录.
- 变更
- 增加了对 GlobalDeviceArray 的完全异步检查点支持。
jax 0.3.8 (2022 年 4 月 29 日)
- GitHub 提交记录.
- 变更
- 在 TPU 上,
jax.numpy.linalg.svd()现在使用 qdwh-svd 求解器。 - 在 TPU 上,
jax.numpy.linalg.cond()现在接受复数输入。 - 在 TPU 上,
jax.numpy.linalg.pinv()现在接受复数输入。 - 在 TPU 上,
jax.numpy.linalg.matrix_rank()现在接受复数输入。 - 已添加
jax.scipy.cluster.vq.vq()。 jax.experimental.maps.mesh已删除。请使用jax.experimental.maps.Mesh。请参阅 此处 获取更多信息。- 当
mode='r'时,jax.scipy.linalg.qr()现在返回一个长度为 1 的元组,而不是原始数组,以匹配scipy.linalg.qr的行为(#10452) jax.numpy.take_along_axis()现在接受一个可选的mode参数,用于指定超出边界索引的行为。默认情况下,超出边界的索引会返回无效值(例如 NaN)。在 JAX 的早期版本中,无效的索引会被夹在范围内。可以通过传递mode="clip"恢复先前的行为。jax.numpy.take()现在默认为mode="fill",这会对超出索引范围的位置返回无效值(例如 NaN)。- 散点操作,例如
x.at[...].set(...),现在具有"drop"语义。这对散点操作本身没有影响,但这意味着在进行微分时,散点的梯度对超出边界的索引的余切为零。以前超出边界的索引在梯度中被夹在范围内,这在数学上是不正确的。 jax.numpy.take_along_axis()现在如果其索引不是整数类型将会引发TypeError,与numpy.take_along_axis()的行为一致。先前非整数索引会被静默转换为整数。jax.numpy.ravel_multi_index()现在如果其dims参数不是整数类型将会引发TypeError,与numpy.ravel_multi_index()的行为一致。先前非整数dims参数会被静默转换为整数。jax.numpy.split()现在如果其axis参数不是整数类型将会引发TypeError,与numpy.split()的行为一致。先前非整数axis参数会被静默转换为整数。jax.numpy.indices()现在如果其维度不是整数类型将会引发TypeError,与numpy.indices()的行为一致。先前非整数维度会被静默转换为整数。jax.numpy.diag()现在如果其k参数不是整数类型将会引发TypeError,与numpy.diag()的行为一致。先前非整数k参数会被静默转换为整数。- 添加了
jax.random.orthogonal()。
- 已过时:
- 许多
jax.test_util中可用的函数和对象现已过时,并将在导入时引发警告。包括cases_from_list、check_close、check_eq、device_under_test、format_shape_dtype_string、rand_uniform、skip_on_devices、with_config、xla_bridge和_default_tolerance(#10389)。这些以及先前过时的JaxTestCase、JaxTestLoader和BufferDonationTestCase将在未来的 JAX 发布中移除。大多数这些实用程序可以通过调用标准的 Python 和 NumPy 测试实用程序来替换,如unittest、absl.testing、numpy.testing等。可以通过公共 API(例如jax.devices())来替换 JAX 特定的功能,如设备检查。许多已过时的实用程序仍然存在于jax._src.test_util中,但这些不是公共 API,因此可能在未来的发布中更改或移除,而不另行通知。
jax 0.3.7(2022 年 4 月 15 日)
- GitHub 提交记录。
- 变更:
- 修复了当传递给
jax.numpy.take_along_axis()的索引广播时的性能问题(#10281)。 jax.scipy.special.expit()和jax.scipy.special.logit()现在要求其参数为标量或 JAX 数组。它们现在还将整数参数提升为浮点数。DeviceArray.tile()方法已弃用,因为 numpy 数组没有tile()方法。作为替代,请使用jax.numpy.tile()(#10266)。
jaxlib 0.3.7(2022 年 4 月 15 日)
- 变更:
- Linux 版本现在符合
manylinux2014标准,而不是manylinux2010。
jax 0.3.6(2022 年 4 月 12 日)
- GitHub 提交记录。
- 变更:
- 将 libtpu 轮子升级到修复初始化 TPU pod 时挂起的版本。修复了 #10218。
- 弃用:
jax.experimental.loops将被弃用。参见 #10278 了解替代 API。
jax 0.3.5(2022 年 4 月 7 日)
- GitHub 提交记录。
- 变更:
- 添加了
jax.random.loggamma()并改进了对小参数值的jax.random.beta()和jax.random.dirichlet()的行为(#9906)。 lax_numpy私有子模块不再暴露在jax.numpy命名空间中(#10029)。- 添加了数组创建例程
jax.numpy.frombuffer()、jax.numpy.fromfunction()和jax.numpy.fromstring()(#10049)。 DeviceArray.copy()现在返回DeviceArray而不是np.ndarray(#10069)- 添加了
jax.scipy.linalg.rsf2csf() jax.experimental.sharded_jit已被弃用,并将很快移除。
- 弃用:
jax.nn.normalize()将被弃用。请使用jax.nn.standardize()替代(#9899)。jax.tree_util.tree_multimap()已弃用。请使用jax.tree_util.tree_map()替代(#5746)。jax.experimental.sharded_jit已弃用。请使用pjit替代。
jaxlib 0.3.5(2022 年 4 月 7 日)
- 修复了 bug
jax 0.3.4(2022 年 3 月 18 日)
jax 0.3.3(2022 年 3 月 17 日)
jax 0.3.2(2022 年 3 月 16 日)
- GitHub 提交记录。
- 变更:
- 函数
jax.ops.index_update、jax.ops.index_add在 0.2.22 中已弃用。请使用JAX 数组上的.at属性,例如,x.at[idx].set(y)。 - 将
jax.experimental.ann.approx_*_k移至jax.lax。这些函数是jax.lax.top_k的优化替代品。 jax.numpy.broadcast_arrays()和jax.numpy.broadcast_to()现在要求标量或类数组输入,并在传递列表时将失败(部分 #7737)。- 标准的
jax[tpu]安装现在可以与 Cloud TPU v4 VMs 一起使用。 pjit现在支持在 CPU 上运行(除了之前的 TPU 和 GPU 支持)。
jaxlib 0.3.2 (2022 年 3 月 16 日)
- 更改
XlaComputation.as_hlo_text()现在支持通过传递布尔标志print_large_constants=True打印大常量。
- 弃用:
JAX数组上的.block_host_until_ready()方法已弃用。请改用.block_until_ready()。
jax 0.3.1 (2022 年 2 月 18 日)
- GitHub 提交记录。
- 更改:
jax.test_util.JaxTestCase和jax.test_util.JaxTestLoader现在已弃用。建议直接使用parametrized.TestCase进行替换。对于依赖于自定义断言(如JaxTestCase.assertAllClose())的测试,请使用标准的 numpy 测试工具,如numpy.testing.assert_allclose(),它们直接与 JAX 数组一起工作(#9620)。jax.test_util.JaxTestCase现在默认设置jax_numpy_rank_promotion='raise'(#9562)。要恢复以前的行为,请使用新的jax.test_util.with_config装饰器:
@jtu.with_config(jax_numpy_rank_promotion='allow') class MyTestCase(jtu.JaxTestCase): ...
- 添加了
jax.scipy.linalg.schur()、jax.scipy.linalg.sqrtm()、jax.scipy.signal.csd()、jax.scipy.signal.stft()、jax.scipy.signal.welch()。
jax 0.3.0 (2022 年 2 月 10 日)
- GitHub 提交记录。
- 更改
- jax 版本已升级至 0.3.0. 请参阅设计文档以获取说明。
jaxlib 0.3.0 (2022 年 2 月 10 日)
- 更改
- 现在需要 Bazel 5.0.0 来构建 jaxlib。
- jaxlib 版本已升级至 0.3.0. 请参阅设计文档以获取说明。
jax 0.2.28 (2022 年 2 月 1 日)
- 如果未传递
dialect=,jax.jit(f).lower(...).compiler_ir()现在默认为 MHLO 方言。 jax.jit(f).lower(...).compiler_ir(dialect='mhlo')现在返回 MLIRir.Module对象,而不是其字符串表示。
jaxlib 0.1.76 (2022 年 1 月 27 日)
- 新功能
- 包括为 NVidia 计算能力 8.0 的 GPU(例如 A100)预编译的 SASS。删除了计算能力 6.1 的预编译 SASS,以避免增加计算能力的数量:具有计算能力 6.1 的 GPU 可以使用 6.0 的 SASS。
- 使用 jaxlib 0.1.76,JAX 默认使用 MHLO MLIR 方言作为其主要目标编译器 IR。
- Breaking changes
- 不再支持 NumPy 1.18,根据弃用策略。请升级到支持的 NumPy 版本。
- Bug 修复
- 修复了一个 bug,即由不同路径构造的表面相同的 pytreedef 对象不会被视为相等(#9066)。
- JAX jit 缓存要求两个静态参数具有相同的类型以进行缓存命中(#9311)。
jax 0.2.27(2022 年 1 月 18 日)
- GitHub 提交。
- Breaking changes:
- 不再支持 NumPy 1.18,根据弃用策略。请升级到支持的 NumPy 版本。
- host_callback 原语已简化,取消了 hcb.id_tap 和 id_print 的特殊自动微分处理。从现在开始,只有原始值被 tap。可以通过设置
JAX_HOST_CALLBACK_AD_TRANSFORMS环境变量或--jax_host_callback_ad_transforms标志来获取旧的行为(在有限时间内)。此外,增加了如何使用 JAX 自定义 AD API 实现旧行为的文档(#8678)。 - 排序现在与 NumPy 的行为匹配,无论位表示如何,对于
0.0和NaN都是如此。特别是,现在0.0和-0.0被视为等价,而之前-0.0被视为小于0.0。此外,所有的NaN表示现在都被视为等价,并且按照这些位模式排序到数组的末尾。以前,负数的NaN值被排序到数组的前面,并且具有不同内部位表示的NaN值不被视为等价,根据这些位模式排序(#9178)。 jax.numpy.unique()现在在处理NaN值时与 NumPy 版本 1.21 及更新版本的np.unique一样:在唯一化的输出中最多只会出现一个NaN值(#9184)。
- Bug 修复:
- 现在 host_callback 支持 ad_checkpoint.checkpoint(#8907)。
- 新功能:
- 添加了
jax.block_until_ready({jax-issue}`#8941)。 - 添加了一个新的调试标志/环境变量
JAX_DUMP_IR_TO=/path。如果设置了,JAX 会将它为每个计算生成的 MHLO/HLO IR 转储到给定路径下的文件。 - 添加了
jax.ensure_compile_time_eval到公共 API(#7987)。 - jax2tf 现在支持一个标志 jax2tf_associative_scan_reductions,用于改变关联约简的降低,例如 jnp.cumsum,在 CPU 和 GPU 上的行为(使用关联扫描)。更多细节请参见 jax2tf README(#9189)。
jaxlib 0.1.75(2021 年 12 月 8 日)
- 新功能:
- 支持 python 3.10。
jax 0.2.26(2021 年 12 月 8 日)
- GitHub 提交记录。
- 错误修复:
- 对
jax.ops.segment_sum的越界索引现在将使用FILL_OR_DROP语义处理,如文档中所述。这主要影响反向模式导数,其中与越界索引对应的梯度现在将返回为 0。(#8634)。 - jax2tf 现在会强制转换代码,使其在 jax.jit 下的代码片段使用 XLA,例如大多数 jax.numpy 函数(#7839)。
jaxlib 0.1.74(2021 年 11 月 17 日)
- 在 GPU 之间启用点对点复制。以前,GPU 复制通过主机反弹,这通常更慢。
- 增加了实验性的 MLIR Python 绑定,供 JAX 使用。
jax 0.2.25(2021 年 11 月 10 日)
- GitHub 提交记录。
- 新功能:
- (实验性)
jax.distributed.initialize暴露多主机 GPU 后端。 jax.random.permutation支持新的independent关键字参数(#8430)
- 破坏性更改
- 将
jax.experimental.stax移至jax.example_libraries.stax - 将
jax.experimental.optimizers移至jax.example_libraries.optimizers
- 新功能:
- 添加了
jax.lax.linalg.qdwh。
jax 0.2.24(2021 年 10 月 19 日)
- GitHub 提交记录。
- 新功能:
jax.random.choice和jax.random.permutation现在支持多维数组和可选的axis参数(#8158)。
- 破坏性更改:
- 现在
jax.numpy.take和jax.numpy.take_along_axis要求数组样式的输入(参见 #7737)。
jaxlib 0.1.73(2021 年 10 月 18 日)
- 现在支持多个 cuDNN 版本的 jaxlib GPU
cuda11轮。
- cuDNN 8.2 或更新版本。如果您的 cuDNN 安装足够新,请使用 cuDNN 8.2 轮,因为它支持额外的功能。
- cuDNN 8.0.5 或更新版本。
- 破坏性更改:
- GPU jaxlib 的安装命令如下:
pip install --upgrade pip # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer. pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html # Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer. pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html # Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer. pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
jax 0.2.22(2021 年 10 月 12 日)
- GitHub 提交记录。
- 破坏性更改
jax.pmap的静态参数现在必须是可哈希的。
在jax.jit上长期不允许非哈希静态参数,但在jax.pmap上仍然允许;jax.pmap使用对象标识比较非哈希静态参数。
这种行为可能会导致一些问题,因为使用对象身份比较来比较参数会导致每次对象身份变化时重新编译。现在我们禁止非可哈希参数:如果jax.pmap的用户希望通过对象身份比较静态参数,他们可以在其对象上定义__hash__和__eq__方法,或者将其对象包装在具有对象身份语义的对象中。另一种选择是使用functools.partial将非可哈希的静态参数封装到函数对象中。jax.util.partial是一个意外导出的内容,已被移除。请使用 Python 标准库中的functools.partial替代。
- Deprecations
- 函数
jax.ops.index_update、jax.ops.index_add等已被弃用,并将在未来的 JAX 版本中移除。请改用 JAX 数组上的.at属性,例如x.at[idx].set(y)。目前,这些函数会产生DeprecationWarning。
- New features:
- 优化的 C++ 代码路径现在是使用 jaxlib 0.1.72 或更新版本时的默认设置,用于提高
pmap的调度时间。可以使用--experimental_cpp_pmap标志(或JAX_CPP_PMAP环境变量)禁用该功能。 jax.numpy.unique现在支持一个可选的fill_value参数(#8121)。
jaxlib 0.1.72 (Oct 12, 2021)
- Breaking changes:
- CUDA 10.2 和 CUDA 10.1 的支持已被移除。Jaxlib 现在支持 CUDA 11.1+。
- Bug fixes:
- 修复了 https://github.com/google/jax/issues/7461,在所有平台上由于 XLA 编译器内部的错误缓冲区别名而导致错误的输出。
jax 0.2.21 (Sept 23, 2021)
- GitHub commits.
- Breaking Changes
jax.api已被移除。之前作为jax.api.*可用的函数现在被别名为jax.*中的函数;请直接使用jax.*中的函数。jax.partial和jax.lax.partial是意外导出的内容,已被移除。请使用 Python 标准库中的functools.partial替代。- 布尔标量索引现在会引发
TypeError;之前这些操作会静默返回错误的结果(#7925)。 - 许多
jax.numpy函数现在要求数组样式的输入,如果传递列表将会报错(#7747 #7802 #7907)。查看 #7737 以了解此更改背后的原因讨论。 - 当在
jax.jit等转换内部时,jax.numpy.array总是将其生成的数组分阶段到跟踪的计算中。以前的jax.numpy.array有时会在jax.jit装饰器下生成一个设备上的数组。这种变化可能会破坏使用 JAX 数组执行必须静态知道形状或索引计算的代码;解决方法是改用经典的 NumPy 数组执行这些计算。 jnp.ndarray现在是 JAX 数组的真正基类。特别地,对于标准的 numpy 数组x,isinstance(x, jnp.ndarray)现在会返回False(#7927)。
- 新特性:
- 添加了
jax.numpy.insert()的实现 (#7936)。
jax 0.2.20 (2021 年 9 月 2 日)
- GitHub 提交记录。
- Breaking Changes
jaxlib 0.1.71 (2021 年 9 月 1 日)
- Breaking changes:
- 不再支持 CUDA 11.0 和 CUDA 10.1。Jaxlib 现在支持 CUDA 10.2 和 CUDA 11.1+。
jax 0.2.19 (2021 年 8 月 12 日)
- GitHub 提交记录。
- Breaking changes:
- 支持 NumPy 1.17 已经被废弃,按照弃用政策。请升级到支持的 NumPy 版本。
- 在 JAX 数组的多个操作的实现周围添加了
jit装饰器。这加快了常见操作如+的调度时间。
这个变化对大多数用户基本上是透明的。但是,有一个已知的行为变化,即直接传递给 JAX 操作符的大整数常数现在可能会产生错误(例如x + 2**40)。解决方法是将常数转换为显式类型(例如np.float64(2**40))。
- 新特性:
- 改进了对需要在数组计算中使用维度大小的操作在 jax2tf 中的形状多态支持,例如
jnp.mean。 (#7317)。
- Bug 修复:
- 上一个版本的泄漏的追踪错误 (#7613)。
jaxlib 0.1.70 (2021 年 8 月 9 日)
- Breaking changes:
- 支持 Python 3.6 已经被废弃,按照弃用政策。请升级到支持的 Python 版本。
- 支持 NumPy 1.17 已经被废弃,按照弃用政策。请升级到支持的 NumPy 版本。
- 现在主机回调机制每个本地设备使用一个线程来调用 Python 回调。以前所有设备共用一个线程。这意味着现在回调可能交错调用。仍然会按顺序调用一个设备对应的所有回调。
jax 0.2.18(2021 年 7 月 21 日)
- GitHub 提交记录。
- Breaking 变更:
- 根据弃用策略,不再支持 Python 3.6。请升级到支持的 Python 版本。
- jaxlib 最低版本现在是 0.1.69。
jax.dlpack.from_dlpack()的backend参数已移除。
- 新功能:
- 添加了极分解(
jax.scipy.linalg.polar())。
- Bug 修复:
- 加强了对 lax.argmin 和 lax.argmax 的检查,以确保它们不会使用无效的
axis值或空的减少维度。 (#7196)
JAX 中文文档(十六)(5)https://developer.aliyun.com/article/1559732