JAX 中文文档(十六)(4)https://developer.aliyun.com/article/1559730
jaxlib 0.1.69(2021 年 7 月 9 日)
- 修复了 TFRT CPU 后端中导致结果不正确的错误。
jax 0.2.17(2021 年 7 月 9 日)
- GitHub 提交记录。
- Bug 修复:
- 对于 jaxlib <= 0.1.68,默认使用较旧的“stream_executor” CPU 运行时,以解决#7229,这导致 CPU 上由于并发问题输出错误结果。
- 新功能:
- 新的 SciPy 函数
jax.scipy.special.sph_harm()
。 - 反向模式自动微分函数(
jax.grad()
,jax.value_and_grad()
,jax.vjp()
和jax.linear_transpose()
)支持一个参数,指示在后向传递中应该对哪些命名轴进行求和,如果它们在前向传递中被广播。这使得可以在 maps 内部以非每个示例的方式使用这些 API(最初仅jax.experimental.maps.xmap()
)(#6950)。
jax 0.2.16(2021 年 6 月 23 日)
jax 0.2.15(2021 年 6 月 23 日)
- GitHub 提交记录。
- 新功能:
- #7042 使用了 TFRT CPU 后端,在 CPU 上显著提升了分派性能。
jax2tf.convert()
支持布尔型不等式和 min/max 函数(#6956)。- 新的 SciPy 函数
jax.scipy.special.lpmn_values()
。
- Breaking 变更:
- 根据弃用策略,不再支持 NumPy 1.16。
- Bug 修复:
- 修复了阻止从 JAX 到 TF 再到 JAX 回传的错误:
jax2tf.call_tf(jax2tf.convert)
(#6947)。
jaxlib 0.1.68(2021 年 6 月 23 日)
- Bug 修复:
- 修复了 TFRT CPU 后端中将 TPU 缓冲区传输到 CPU 时出现 NaN 的错误。
jax 0.2.14(2021 年 6 月 10 日)
- GitHub 提交记录。
- 新功能:
jax2tf.convert()
现在支持pjit
和sharded_jit
。- 新的配置选项 JAX_TRACEBACK_FILTERING 控制 JAX 如何过滤回溯信息。
- 在足够新的 IPython 版本中,默认启用了使用
__tracebackhide__
的新的回溯过滤模式。 jax2tf.convert()
在算术操作中使用未知维度时,即使在形状多态性中,也支持形状多态性,例如jnp.reshape(-1)
(#6827)。jax2tf.convert()
现在在 TF 操作中生成具有位置信息的自定义属性。在 jax2tf 之后 XLA 生成的代码具有与 JAX/XLA 相同的位置信息。- 新的 SciPy 函数
jax.scipy.special.lpmn()
。
- Bug fixes:
jax2tf.convert()
现在确保对于 Python 标量和选择 32 位 vs. 64 位计算时使用相同的类型规则,如 JAX(#6883)。jax2tf.convert()
现在正确地将enable_xla
转换参数限定范围到仅在即时转换期间应用(#6720)。jax2tf.convert()
现在使用XlaDot
TensorFlow 操作来转换lax.dot_general
,以提高与 JAX 数值精度的一致性(#6717)。jax2tf.convert()
现在支持复数的不等式比较和最小/最大值(#6892)。
jaxlib 0.1.67(2021 年 5 月 17 日)
jaxlib 0.1.66(2021 年 5 月 11 日)
- 新特性:
- 现在支持在所有 CUDA 11 版本(11.1 或更高版本)上使用 CUDA 11.1 wheels。
NVIDIA 现在承诺从 CUDA 11.1 开始兼容 CUDA 小版本更新。这意味着 JAX 可以发布一个兼容 CUDA 11.2 和 11.3 的单个 CUDA 11.1 wheel。
不再为 CUDA 11.2(或更高版本)发布单独的 jaxlib 版本;对于这些版本,请使用 CUDA 11.1 wheel(cuda111)。 - Jaxlib 现在在 CUDA wheels 中捆绑
libdevice.10.bc
。不需要指定 CUDA 安装路径来查找此文件。 jit()
实现自动支持静态关键字参数。- 添加了对预转换异常跟踪的支持。
- 初步支持从
jit()
转换的计算中剪枝未使用的参数。剪枝仍在进行中。 - 改进了
PyTreeDef
对象的字符串表示。 - 添加了对 XLA 可变 ReduceWindow 的支持。
- Bug fixes:
- 修复了在远程云 TPU 支持中传递大量参数时的 bug。
- 修复了一个问题,即
jit()
转换的函数未触发 JAX 垃圾回收。
jax 0.2.13(2021 年 5 月 3 日)
- GitHub 提交。
- 新特性:
- 结合 jaxlib 0.1.66 使用时,
jax.jit()
现在支持静态关键字参数。新增了static_argnames
选项以指定关键字参数为静态。 jax.nonzero()
现在有一个新的可选参数size
,允许在jit
内使用 (#6501)。jax.numpy.unique()
现在支持axis
参数 (#6532)。jax.experimental.host_callback.call()
现在支持pjit.pjit
(#6569)。- 添加了
jax.scipy.linalg.eigh_tridiagonal()
,用于计算三对角矩阵的特征值。目前仅支持特征值。 - 异常中筛选和未筛选的堆栈跟踪顺序已更改。从 JAX 转换代码中抛出的异常现在附带有过滤后的回溯,
UnfilteredStackTrace
异常包含原始跟踪作为过滤异常的__cause__
。现在,筛选的堆栈跟踪也适用于 Python 3.6。 - 如果由反向模式自动微分转换的代码引发异常,JAX 现在尝试附加一个
JaxStackTraceBeforeTransformation
对象作为异常的__cause__
,该对象包含在正向传递中创建原始操作的堆栈跟踪。需要 jaxlib 0.1.66。
- 破坏性变更:
- 下列函数名称已更改。仍然存在别名,因此不应该破坏现有代码,但别名最终将被移除,请更改您的代码。
host_id
–>process_index()
host_count
–>process_count()
host_ids
–>range(jax.process_count())
- 同样地,
local_devices()
的参数已从host_id
重命名为process_index
。 - 除了函数之外的
jax.jit()
参数现在标记为仅限关键字。此更改旨在防止在向jit
添加参数时意外破坏代码。
- Bug 修复:
- 现在
jax2tf.convert()
在带有整数输入的函数梯度存在时能正常工作 (#6360)。 - 修复了
jax2tf.call_tf()
在与捕获的tf.Variable
结合使用时的断言失败 (#6572)。
jaxlib 0.1.65(2021 年 4 月 7 日)
jax 0.2.12(2021 年 4 月 1 日)
- GitHub 提交记录。
- 新功能
- 新的分析 API:
jax.profiler.start_trace()
,jax.profiler.stop_trace()
和jax.profiler.trace()
jax.lax.reduce()
现在可微分。
- 破坏性变更:
- 最低的 jaxlib 版本现在是 0.1.64。
- 一些分析器 API 名称已更改。仍然存在别名,因此不应该破坏现有代码,但别名最终将被移除,请更改您的代码。
TraceContext
–>TraceAnnotation()
StepTraceContext
–>StepTraceAnnotation()
trace_function
–>annotate_function()
- 无法禁用全局分析。有关更多信息,请参阅 omnistaging。
- Python 整数大于最大的
int64
值现在在所有情况下都会导致溢出,而不是在某些情况下静默转换为uint64
(#6047)。 - 在非 X64 模式下,超出
int32
可表示范围的 Python 整数现在将导致OverflowError
,而不是静默截断其值。
- Bug 修复:
host_callback
现在支持参数和结果中的空数组(#6262)。jax.random.randint()
在超出限制范围时会剪切而不是包裹,现在可以生成指定 dtype 的整数的全部范围(#5868)。
jax 0.2.11(2021 年 3 月 23 日)
- GitHub 提交记录。
- 新特性:
- #6112 添加了上下文管理器:
jax.enable_checks
,jax.check_tracer_leaks
,jax.debug_nans
,jax.debug_infs
,jax.log_compiles
。 - #6085 添加了
jnp.delete
- Bug 修复:
- #6136 泛化了
jax.flatten_util.ravel_pytree
以处理整数 dtype。 - #6129 修复了处理像
enum.IntEnums
这样的一些常量的错误 - #6145 修复了不完全贝塔函数批处理问题
- #6014 修复了追踪过程中的 H2D 传输问题
- #6165 在将一些大的 Python 整数转换为浮点数时避免 OverflowErrors
- 破坏性变更:
- jaxlib 最小版本现在是 0.1.62。
jaxlib 0.1.64(2021 年 3 月 18 日)
jaxlib 0.1.63(2021 年 3 月 17 日)
jax 0.2.10(2021 年 3 月 5 日)
- GitHub 提交记录。
- 新特性:
jax.scipy.stats.chi2()
现在作为具有 logpdf 和 pdf 方法的分布可用。jax.scipy.stats.betabinom()
现在作为具有 logpmf 和 pmf 方法的分布可用。- 添加了
jax.experimental.jax2tf.call_tf()
以从 JAX 调用 TensorFlow 函数(#5627)和README。 - 扩展了
lax.pad
的批处理规则以支持填充值的批处理。
- Bug 修复:
jax.numpy.take()
正确处理负索引(#5768)
- 破坏性变更:
- 调整了 JAX 的提升规则,使提升更一致且不受 JIT 影响。特别是,当适当时,二进制操作现在可以产生弱类型值。更改的主要用户可见效果是某些操作的输出精度与之前不同;例如表达式
jnp.bfloat16(1) + 0.1 * jnp.arange(10)
以前返回float64
数组,现在返回bfloat16
数组。JAX 的类型提升行为在类型提升语义中描述。 jax.numpy.linspace()
现在计算整数值的地板,即向负无穷取整,而不是向 0 取整。此更改是为了与 NumPy 1.20.0 保持一致。jax.numpy.i0()
不再接受复数。之前该函数计算复数参数的绝对值。此更改是为了与 NumPy 1.20.0 的语义保持一致。- 几个
jax.numpy
函数不再接受元组或列表作为数组参数的替代:jax.numpy.pad()
,jax.numpy.ravel
,jax.numpy.repeat()
,jax.numpy.reshape()
。通常情况下,应使用标量或数组参数调用jax.numpy
函数。
jaxlib 0.1.62 (2021 年 3 月 9 日)
- 新特性:
- 在 x86-64 机器上,默认情况下构建 jaxlib wheels 需要 AVX 指令。如果要在不支持 AVX 的机器上使用 JAX,可以使用
build.py
的--target_cpu_features
标志从源代码构建 jaxlib。--target_cpu_features
还替换了--enable_march_native
。
jaxlib 0.1.61 (2021 年 2 月 12 日)
jaxlib 0.1.60 (2021 年 2 月 3 日)
- 错误修复:
- 修复了将 CPU DeviceArrays 转换为 NumPy 数组时的内存泄漏问题。在 jaxlib 发布的 0.1.58 和 0.1.59 版本中存在该内存泄漏。
bool
,int8
和uint8
现在被认为是安全的,可以转换为bfloat16
NumPy 扩展类型。
jax 0.2.9 (2021 年 1 月 26 日)
- GitHub 提交记录.
- 新特性:
- 扩展
jax.experimental.loops
模块以支持 pytrees。改进了错误检查和错误消息。 - 添加
jax.experimental.enable_x64()
和jax.experimental.disable_x64()
。这些是上下文管理器,允许在会话中临时启用/禁用 X64 模式。
- 破坏性变更:
jax.ops.segment_sum()
现在在性能考虑下删除超出范围的段 ID,而不是将它们包装到段 ID 空间。
jaxlib 0.1.59 (2021 年 1 月 15 日)
jax 0.2.8 (2021 年 1 月 12 日)
- GitHub 提交记录.
- 新特性:
- 添加
jax.closure_convert()
用于与高阶自定义导数函数一起使用。 (#5244) - 添加
jax.experimental.host_callback.call()
以调用主机上的自定义 Python 函数并将结果返回到设备计算中。 (#5243)
- 错误修复:
jax.numpy.arccosh
现在对复数输入返回与numpy.arccosh
相同的分支(#5156)。- 现在
host_callback.id_tap
在jax.pmap
中也可以使用。对于id_tap
和id_print
,现在有一个可选参数,可以请求将值从中提取的设备作为关键字参数传递给 tap 函数(#5182)。
- 破坏性更改:
jax.numpy.pad
现在接受关键字参数。位置参数constant_values
已被移除。此外,传递不受支持的关键字参数将引发错误。jax.experimental.host_callback.id_tap()
的更改(#5243):
- 删除了对
jax.experimental.host_callback.id_tap()
的kwargs
支持(这种支持已经被弃用几个月了)。 - 更改了
jax.experimental.host_callback.id_print()
中元组的打印方式,使用了(
而不是‘‘
。 - 在
jax.experimental.host_callback.id_print()
存在 JVP 的情况下,更改了打印元组的方式,现在使用了一对主元和切线。以前是分别打印主元和切线。 - 删除了
host_callback.outfeed_receiver
(这不再需要,并且几个月前已被弃用)。
- 新功能:
- 为
inf
的调试添加了一个新标志,类似于NaN
的标志(#5224)。
jax 0.2.7(2020 年 12 月 4 日)
- GitHub 提交。
- 新功能:
- 添加了
jax.device_put_replicated
。 - 向
jax.experimental.sharded_jit
添加了多主机支持。 - 增加对
jax.numpy.linalg.eig
计算的特征值的微分支持。 - 增加了对在 Windows 平台上构建的支持。
- 在
jax.pmap
中添加了对通用in_axes
和out_axes
的支持。 - 添加了对
jax.numpy.linalg.slogdet
的复数支持。
- Bug 修复:
- 修复
jax.numpy.sinc
在零点处高于二阶导数的问题。 - 修复了在转置规则中的符号零的一些难以命中的 bug。
- 破坏性更改:
- 已删除
jax.experimental.optix
,改为独立的optax
Python 包。 - 使用非元组序列索引 JAX 数组现在会引发
TypeError
。这种类型的索引自从 Numpy v1.16 和 JAX v0.2.4 开始已经被弃用。参见 #4564。
jax 0.2.6(2020 年 11 月 18 日)
- GitHub 提交。
- 新功能:
- 为
jax.experimental.jax2tf
转换器的形状多态跟踪添加了支持。参见 README.md。
- 破坏性更改清理:
- 将复杂的 Python 标量添加到 JAX 浮点数会保留 JAX 浮点数的精度。例如,
jnp.float32(1) + 1j
现在返回complex64
,而之前返回的是complex128
。 - 当涉及到包含
uint64
、有符号整型和第三种类型的三个或更多术语的类型提升时,现在与参数顺序无关。例如:jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)
和jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)
都返回float16
,之前第一个返回float64
,第二个返回float16
。
- (未记录的)
jax.lax_linalg
线性代数模块现在公开为jax.lax.linalg
。 jax.random.PRNGKey
现在在 JIT 编译内外产生相同的结果 (#4877)。这需要在几个特定情况下更改给定种子的结果:
- 使用
jax_enable_x64=False
时,作为 Python 整数传递的负数种子现在在 JIT 模式外返回不同的结果。例如,jax.random.PRNGKey(-1)
以前返回[4294967295, 4294967295]
,现在返回[0, 4294967295]
。这与 JIT 中的行为一致。 - JIT 外部的
int64
不能表示的范围外的种子现在会导致OverflowError
而不是TypeError
。这与 JIT 中的行为一致。
- 要恢复在
jax_enable_x64=False
时以前针对负整数返回的键,可以使用:
key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
- 当尝试访问已删除其值的
DeviceArray
时,现在会引发RuntimeError
而不是ValueError
。
jaxlib 0.1.58 (2021 年 1 月 12 日)
- 修复了 JAX 有时返回平台特定类型(如
np.cint
)而不是标准类型(如np.int32
)的 Bug (#4903)。 - 修复了在执行某些 int16 操作时常量折叠导致崩溃的问题 (#4971)。
- 在
pytree.flatten()
中添加了一个is_leaf
谓词。
jaxlib 0.1.57 (2020 年 11 月 12 日)
- 修复了 GPU wheels 中的 manylinux2010 兼容性问题。
- 将 CPU FFT 实现从 Eigen 切换到 PocketFFT。
- 修复了 bfloat16 值哈希未正确初始化并可能更改的 Bug (#4651)。
- 添加了对将数组传递给 DLPack 时保留所有权的支持 (#4636)。
- 修复了批量三角求解的一个 Bug,对大于 128 但不是 128 的倍数的情况。
- 修复了在多个 GPU 上同时进行并发 FFT 时的 Bug (#3518)。
- 在分析器中修复了工具缺失的 Bug (#4427)。
- 放弃了对 CUDA 10.0 的支持。
jax 0.2.5 (2020 年 10 月 27 日)
- GitHub 提交记录。
- 改进:
- 确保
check_jaxpr
不执行 FLOPS。参见 #4650。 - 扩展了由 jax2tf 转换的 JAX 原语集。参见 primitives_with_limited_support.md。
jax 0.2.4 (2020 年 10 月 19 日)
- GitHub 提交记录。
- 改进:
- 为
jax.experimental.host_callback
添加了对remat
的支持。参见 #4608。
- 弃用
- 现在,使用非元组序列进行索引已被弃用,遵循 Numpy 中的类似弃用。在将来的版本中,这将导致 TypeError。参见 #4564。
jaxlib 0.1.56 (2020 年 10 月 14 日)。
jax 0.2.3 (2020 年 10 月 14 日)。
- GitHub 提交记录。
- 由于需要暂时回退新的 jit 快速通路,因此又进行了一个新的发布。
jax 0.2.2 (2020 年 10 月 13 日)。
jax 0.2.1 (2020 年 10 月 6 日)。
- GitHub 提交记录。
- 改进:
- 作为全阶段的一个好处,即使
jax.experimental.host_callback.id_print()
/jax.experimental.host_callback.id_tap()
的结果未在计算中使用,也会按程序顺序执行 host_callback 函数。
jax (0.2.0) (2020 年 9 月 23 日)。
- GitHub 提交记录。
- 改进:
- 默认情况下启用全阶段。参见 #3370 和 omnistaging。
jax (0.1.77) (2020 年 9 月 15 日)。
- 破坏性变更:
jax.experimental.host_callback.id_tap()
的新简化接口 (#4101)。
jaxlib 0.1.55 (2020 年 9 月 8 日)。
- 更新 XLA:
- 修复 DLPackManagedTensorToBuffer 中的错误 (#4196)。
jax 0.1.76 (2020 年 9 月 8 日)。
jax 0.1.75 (2020 年 7 月 30 日)。
- GitHub 提交记录。
- Bug 修复:
- 使 jnp.abs() 适用于无符号输入 (#3914)。
- 改进:
- 添加了“全阶段”行为,但在默认情况下已禁用 (#3370)。
jax 0.1.74 (2020 年 7 月 29 日)。
- GitHub 提交记录。
- 新功能:
- BFGS (#3101)。
- TPU 支持半精度算术 (#3878)。
- Bug 修复:
- 防止一些意外的 dtype 警告 (#3874)。
- 修复自定义导数中的多线程错误 (#3845, #3869)。
- 改进:
- 更快的 searchsorted 实现 (#3873)。
- 为 jax.numpy 排序算法提供更好的测试覆盖率 (#3836)。
jaxlib 0.1.52 (2020 年 7 月 22 日)。
- 更新 XLA。
jax 0.1.73 (2020 年 7 月 22 日)。
- GitHub 提交记录。
- jaxlib 的最低版本现在是 0.1.51。
- 新功能:
- jax.image.resize. (#3703)。
- hfft 和 ihfft (#3664)。
- jax.numpy.intersect1d (#3726)。
- jax.numpy.lexsort (#3812)。
- 当降低到 XLA 时,
lax.scan
和scan
原语支持一个unroll
参数用于循环展开 (#3738)。
- Bug 修复:
- 修复重复轴错误的约简 (#3618)。
- 修复 lax.pad 对输入维度大小为 0 的形状规则错误。 (#3608)。
- 使 psum 转置处理零余切 (#3653)。
- 修复在尺寸为 0 的轴上进行 reduce-prod 的 JVP 的形状错误 (#3729)。
- 支持通过 jax.lax.all_to_all 进行微分。
- 解决了 jax.scipy.special.zeta 中的 nan 问题。(#3777)
- 改进:
- 对 jax2tf 进行了许多改进。
- 重新实现了使用单次变量减少的 argmin/argmax。(#3611)
- 默认启用 XLA SPMD 分区。(#3151)
- 支持 0d 转置卷积。(#3643)
- 使低秩矩阵的 LU 梯度工作。
- 支持 jet 中的多结果和自定义 JVPs。
- 通用化了 reduce-window 的填充,支持(lo, hi)对。(#3728)
- 在 CPU 和 GPU 上实现复杂卷积。(#3735)
- 使 jnp.take 在空数组的空切片上工作。(#3751)
- 放宽了 dot_general 的维度排序规则。(#3778)
- 启用 GPU 的缓冲捐赠。(#3800)
- 为减少窗口操作添加了基本扩张和窗口扩张支持…(#3803)
jaxlib 0.1.51(2020 年 7 月 2 日)
- 更新 XLA。
- 添加了对 host_callback 的新运行时支持。
jax 0.1.72(2020 年 6 月 28 日)
- GitHub 提交记录。
- Bug 修复:
- 修复了前一个版本中引入的 odeint Bug,见 #3587。
jax 0.1.71(2020 年 6 月 25 日)
- GitHub 提交记录。
- 现在的 jaxlib 最低版本要求是 0.1.48。
- Bug 修复:
- 允许
jax.experimental.ode.odeint
动态函数在我们对其进行微分的值上进行闭包 #3562。
jaxlib 0.1.50(2020 年 6 月 25 日)
- 增加了对 CUDA 11.0 的支持。
- 放弃对 CUDA 9.2 的支持(我们只支持最后四个 CUDA 版本)。
- 更新 XLA。
jaxlib 0.1.49(2020 年 6 月 19 日)
- Bug 修复:
- 修复了编译问题,可能导致编译速度慢(tensorflow/tensorflow)。
jaxlib 0.1.48(2020 年 6 月 12 日)
- 新特性:
- 增加了快速回溯收集的支持。
- 增加了对设备堆分析的初步支持。
- 为
bfloat16
类型实现了np.nextafter
。 - CPU 和 GPU 上的 Complex128 支持 FFT。
- Bug 修复:
- 改进了在 GPU 上
tanh
的 float64 精度。 - GPU 上的 float64 散布现在更快了。
- 在 CPU 上的复杂矩阵乘法应该更快了。
- CPU 上的稳定排序现在实际上是稳定的了。
- CPU 后端的并发 Bug 修复。
jax 0.1.70(2020 年 6 月 8 日)
- GitHub 提交记录。
- 新特性:
lax.switch
引入了带有多分支的索引条件,并与cond
原语的泛化一起使用 #3318。
jax 0.1.69(2020 年 6 月 3 日)
jax 0.1.68(2020 年 5 月 21 日)
- GitHub 提交记录。
- 新特性:
lax.cond()
支持单操作数形式,作为两个分支的参数 #2993。
- 注意事项改动:
jax.experimental.host_callback.id_tap()
原语的transforms
关键字格式已更改 #3132。
jax 0.1.67(2020 年 5 月 12 日)
- GitHub 提交记录。
- 新功能:
- 支持使用
axis_index_groups
对 pmapped 轴的子集进行缩减 #2382。 - 实验性支持从编译代码调用和打印主机端 Python 函数。参见 id_print 和 id_tap(#3006)。
- 显著变更:
- 从
jax.numpy
导出的名称的可见性已加强。这可能会破坏之前无意中使用这些名称的代码。
jaxlib 0.1.47(2020 年 5 月 8 日)
- 修复 outfeed 引起的崩溃。
jax 0.1.66(2020 年 5 月 5 日)
- GitHub 提交记录。
- 新功能:
- 支持在
pmap()
上使用in_axes=None
进行缩减 #2896。
jaxlib 0.1.46(2020 年 5 月 5 日)
- 修复 Mac OS X 上线性代数函数的崩溃(#432)。
- 修复使用 AVX512 指令时因操作系统或虚拟化程序禁用而导致的非法指令崩溃问题(#2906)。
jax 0.1.65(2020 年 4 月 30 日)
- GitHub 提交记录。
- 新功能:
- 对奇异矩阵行列式的微分 #2809。
- Bug 修复:
jaxlib 0.1.45(2020 年 4 月 21 日)
- 修复段错误:#2755
- 在 Sort HLO 上通过 Plumb 选项支持稳定性。
jax 0.1.64(2020 年 4 月 21 日)
- GitHub 提交记录。
- 新功能:
- 添加函数式索引更新的语法糖 #2684。
- 添加
jax.numpy.linalg.multi_dot()
#2726。 - 添加
jax.numpy.unique()
#2760。 - 添加
jax.numpy.rint()
#2724。 - 添加
jax.numpy.rint()
#2724。 - 为
jax.experimental.jet()
添加更多原始规则。
- Bug 修复:
- 更好的错误修复:
- 改进
lax.while_loop()
的反向模式微分的错误消息 #2129。
jaxlib 0.1.44(2020 年 4 月 16 日)
- 修复了一个 bug,即当存在多个不同型号的 GPU 时,JAX 只会编译适用于第一个 GPU 的程序。
- 修复了
batch_group_count
卷积的错误。 - 为更多 GPU 版本添加了预编译的 SASS,以避免启动时 PTX 编译挂起。
jax 0.1.63 (2020 年 4 月 12 日)
- GitHub 提交记录。
- 添加了
jax.custom_jvp
和jax.custom_vjp
,来源于 #2026,请参阅教程笔记本。弃用了jax.custom_transforms
并将其从文档中删除(尽管它仍然可用)。 - 添加了
scipy.sparse.linalg.cg
#2566。 - 更改了 Tracers 的打印方式,以显示更多有用的调试信息 #2591。
- 修复了
jax.numpy.isclose
正确处理nan
和inf
的方式 #2501。 - 添加了几个
jax.experimental.jet
的新规则 #2537。 - 当未提供
scale
/center
时,修复了jax.experimental.stax.BatchNorm
。 - 修复了
jax.numpy.einsum
中广播的一些缺失情况 #2512。 - 通过并行前缀扫描实现了
jax.numpy.cumsum
和jax.numpy.cumprod
,并使reduce_prod
对任意阶数可微分 #2596 #2597。 - 在
conv_general_dilated
中添加了batch_group_count
#2635。 - 为
test_util.check_grads
添加了文档字符串 #2656。 - 添加了
callback_transform
#2665。 - 实现了
rollaxis
、convolve
/correlate
的 1 维和 2 维、copysign
、trunc
、roots
以及quantile
/percentile
的插值选项。
jaxlib 0.1.43 (2020 年 3 月 31 日)
- 修复了 GPU 上 Resnet-50 的性能回归问题。
jax 0.1.62 (2020 年 3 月 21 日)
- GitHub 提交记录。
- JAX 已停止支持 Python 3.5。请升级到 Python 3.6 或更新版本。
- 删除了内部函数
lax._safe_mul
,该函数实现了约定0. * nan == 0.
。此更改意味着在某些程序被微分时会产生 nan,而不是以前产生正确值,尽管这确保了对其他程序产生 nan 而不是静默的不正确结果。详见 #2447 和 #1052。 - 添加了一个
all_gather
并行便利函数。 - 在核心代码中增加了更多类型注解。
jaxlib 0.1.42 (2020 年 3 月 19 日)
- jaxlib 0.1.41 由于 API 不兼容性破坏了云 TPU 支持。此版本修复了这个问题。
- JAX 已停止支持 Python 3.5。请升级到 Python 3.6 或更新版本。
jax 0.1.61 (2020 年 3 月 17 日)
- GitHub 提交记录。
- 修复 Python 3.5 支持。这将是 JAX 或 jaxlib 版本的最后一个支持 Python 3.5 的版本。
jax 0.1.60(2020 年 3 月 17 日)
- GitHub 提交。
- 新功能:
jax.pmap()
增加了static_broadcast_argnums
参数,该参数允许用户指定应该作为编译时常数处理的参数,并应广播到所有设备。它类似于jax.jit()
中的static_argnums
。- 改善了错误消息,以防止错误地在全局状态中保存跟踪器。
- 添加了
jax.nn.one_hot()
实用函数。 - 添加了
jax.experimental.jet
,用于更快的高阶自动微分。 - 对
jax.lax.broadcast_in_dim()
的参数进行了更多正确性检查。
- 最小 jaxlib 版本现已是 0.1.41。
jaxlib 0.1.40(2020 年 3 月 4 日)
- 添加了 Jaxlib 对 TensorFlow 分析仪的实验性支持,该分析仪允许从 TensorBoard 跟踪 CPU 和 GPU 计算。
- 包括多主机 GPU 计算支持的原型,该计算通过 NCCL 通信。
- 改善了在 GPU 上的 NCCL 集合性能。
- 添加了 TopK、CustomCallWithoutLayout、CustomCallWithLayout、IGammaGradA 和 RandomGamma 实现。
- 支持在 XLA 编译时已知的设备分配。
jax 0.1.59(2020 年 2 月 11 日)
- GitHub 提交。
- 重大更改
- 最小 jaxlib 版本现已是 0.1.38。
- 简化
Jaxpr
,通过删除Jaxpr.freevars
和Jaxpr.bound_subjaxprs
。调用基本功能(xla_call
、xla_pmap
、sharded_call
和remat_call
)获取一个新的参数call_jaxpr
,它具有一个完全闭合(无constvars
)的 jaxpr。此外,还添加了一个新的字段call_primitive
到基本功能。
- 新功能:
- 反向模式自动微分(例如
grad
)对lax.cond
的支持,使其在两种模式下都可微分(#2091) - JAX 现在支持 DLPack,它允许以零副本方式共享 CPU 和 GPU 数组与其他库(例如 PyTorch)。
- JAX GPU DeviceArrays 现在支持
__cuda_array_interface__
,这是另一种用于与 CuPy 和 Numba 等库共享 GPU 数组的零副本协议。 - JAX 的 CPU 设备缓冲区现在实现了 Python 缓冲区协议,这允许 JAX 和 NumPy 之间的零副本缓冲区共享。
- 添加了名为 JAX_SKIP_SLOW_TESTS 的环境变量,以跳过已知为慢的测试。
jaxlib 0.1.39(2020 年 2 月 11 日)
- 更新 XLA。
jaxlib 0.1.38(2020 年 1 月 29 日)
- 不再支持 CUDA 9.0。
- 默认构建 CUDA 10.2 的轮。
jax 0.1.58(2020 年 1 月 28 日)
- [GitHub GitHub 提交。
- 重大更改
- JAX 已弃用对 Python 2 的支持,因为 Python 2 于 2020 年 1 月 1 日达到生命周期结束。请更新到 Python 3.5 或更新版本。
- 新功能
- 正向模式自动微分(
jvp
)对 while 循环的支持(#1980)- 新的 NumPy 和 SciPy 功能:
jax.numpy.fft.fft2() jax.numpy.fft.ifft2() jax.numpy.fft.rfft() jax.numpy.fft.irfft() jax.numpy.fft.rfft2() jax.numpy.fft.irfft2() jax.numpy.fft.rfftn() jax.numpy.fft.irfftn() jax.numpy.fft.fftfreq() jax.numpy.fft.rfftfreq() jax.numpy.linalg.matrix_rank() jax.numpy.linalg.matrix_power() jax.scipy.special.betainc()
- 现在在 GPU 上进行批次 Cholesky 分解时使用了更高效的批次核心。
显著的错误修复
- 使用 Python 3 升级后,JAX 不再依赖于
fastcache
,这应该有助于安装。