JAX 中文文档(十六)(5)

简介: JAX 中文文档(十六)

JAX 中文文档(十六)(4)https://developer.aliyun.com/article/1559730


jaxlib 0.1.69(2021 年 7 月 9 日)

  • 修复了 TFRT CPU 后端中导致结果不正确的错误。

jax 0.2.17(2021 年 7 月 9 日)

  • 对于 jaxlib <= 0.1.68,默认使用较旧的“stream_executor” CPU 运行时,以解决#7229,这导致 CPU 上由于并发问题输出错误结果。
  • 新功能:
  • 新的 SciPy 函数jax.scipy.special.sph_harm()
  • 反向模式自动微分函数(jax.grad()jax.value_and_grad()jax.vjp()jax.linear_transpose())支持一个参数,指示在后向传递中应该对哪些命名轴进行求和,如果它们在前向传递中被广播。这使得可以在 maps 内部以非每个示例的方式使用这些 API(最初仅 jax.experimental.maps.xmap())(#6950)。

jax 0.2.16(2021 年 6 月 23 日)

jax 0.2.15(2021 年 6 月 23 日)

  • #7042 使用了 TFRT CPU 后端,在 CPU 上显著提升了分派性能。
  • jax2tf.convert() 支持布尔型不等式和 min/max 函数(#6956)。
  • 新的 SciPy 函数jax.scipy.special.lpmn_values()
  • Breaking 变更:
  • Bug 修复:
  • 修复了阻止从 JAX 到 TF 再到 JAX 回传的错误:jax2tf.call_tf(jax2tf.convert)#6947)。

jaxlib 0.1.68(2021 年 6 月 23 日)

  • Bug 修复:
  • 修复了 TFRT CPU 后端中将 TPU 缓冲区传输到 CPU 时出现 NaN 的错误。

jax 0.2.14(2021 年 6 月 10 日)

  • jax2tf.convert() 现在支持 pjitsharded_jit
  • 新的配置选项 JAX_TRACEBACK_FILTERING 控制 JAX 如何过滤回溯信息。
  • 在足够新的 IPython 版本中,默认启用了使用 __tracebackhide__ 的新的回溯过滤模式。
  • jax2tf.convert() 在算术操作中使用未知维度时,即使在形状多态性中,也支持形状多态性,例如 jnp.reshape(-1)#6827)。
  • jax2tf.convert() 现在在 TF 操作中生成具有位置信息的自定义属性。在 jax2tf 之后 XLA 生成的代码具有与 JAX/XLA 相同的位置信息。
  • 新的 SciPy 函数 jax.scipy.special.lpmn()
  • Bug fixes:
  • jax2tf.convert() 现在确保对于 Python 标量和选择 32 位 vs. 64 位计算时使用相同的类型规则,如 JAX(#6883)。
  • jax2tf.convert() 现在正确地将 enable_xla 转换参数限定范围到仅在即时转换期间应用(#6720)。
  • jax2tf.convert() 现在使用 XlaDot TensorFlow 操作来转换 lax.dot_general,以提高与 JAX 数值精度的一致性(#6717)。
  • jax2tf.convert() 现在支持复数的不等式比较和最小/最大值(#6892)。

jaxlib 0.1.67(2021 年 5 月 17 日)

jaxlib 0.1.66(2021 年 5 月 11 日)

  • 新特性:
  • 现在支持在所有 CUDA 11 版本(11.1 或更高版本)上使用 CUDA 11.1 wheels。
    NVIDIA 现在承诺从 CUDA 11.1 开始兼容 CUDA 小版本更新。这意味着 JAX 可以发布一个兼容 CUDA 11.2 和 11.3 的单个 CUDA 11.1 wheel。
    不再为 CUDA 11.2(或更高版本)发布单独的 jaxlib 版本;对于这些版本,请使用 CUDA 11.1 wheel(cuda111)。
  • Jaxlib 现在在 CUDA wheels 中捆绑 libdevice.10.bc。不需要指定 CUDA 安装路径来查找此文件。
  • jit() 实现自动支持静态关键字参数。
  • 添加了对预转换异常跟踪的支持。
  • 初步支持从 jit() 转换的计算中剪枝未使用的参数。剪枝仍在进行中。
  • 改进了 PyTreeDef 对象的字符串表示。
  • 添加了对 XLA 可变 ReduceWindow 的支持。
  • Bug fixes:
  • 修复了在远程云 TPU 支持中传递大量参数时的 bug。
  • 修复了一个问题,即 jit() 转换的函数未触发 JAX 垃圾回收。

jax 0.2.13(2021 年 5 月 3 日)

  • 结合 jaxlib 0.1.66 使用时,jax.jit() 现在支持静态关键字参数。新增了 static_argnames 选项以指定关键字参数为静态。
  • jax.nonzero() 现在有一个新的可选参数 size,允许在 jit 内使用 (#6501)。
  • jax.numpy.unique() 现在支持 axis 参数 (#6532)。
  • jax.experimental.host_callback.call() 现在支持 pjit.pjit (#6569)。
  • 添加了 jax.scipy.linalg.eigh_tridiagonal(),用于计算三对角矩阵的特征值。目前仅支持特征值。
  • 异常中筛选和未筛选的堆栈跟踪顺序已更改。从 JAX 转换代码中抛出的异常现在附带有过滤后的回溯,UnfilteredStackTrace 异常包含原始跟踪作为过滤异常的 __cause__。现在,筛选的堆栈跟踪也适用于 Python 3.6。
  • 如果由反向模式自动微分转换的代码引发异常,JAX 现在尝试附加一个 JaxStackTraceBeforeTransformation 对象作为异常的 __cause__,该对象包含在正向传递中创建原始操作的堆栈跟踪。需要 jaxlib 0.1.66。
  • 破坏性变更:
  • 下列函数名称已更改。仍然存在别名,因此不应该破坏现有代码,但别名最终将被移除,请更改您的代码。
  • host_id –> process_index()
  • host_count –> process_count()
  • host_ids –> range(jax.process_count())
  • 同样地,local_devices() 的参数已从 host_id 重命名为 process_index
  • 除了函数之外的 jax.jit() 参数现在标记为仅限关键字。此更改旨在防止在向 jit 添加参数时意外破坏代码。
  • Bug 修复:
  • 现在 jax2tf.convert() 在带有整数输入的函数梯度存在时能正常工作 (#6360)。
  • 修复了 jax2tf.call_tf() 在与捕获的 tf.Variable 结合使用时的断言失败 (#6572)。

jaxlib 0.1.65(2021 年 4 月 7 日)

jax 0.2.12(2021 年 4 月 1 日)

  • 新的分析 API:jax.profiler.start_trace()jax.profiler.stop_trace()jax.profiler.trace()
  • jax.lax.reduce() 现在可微分。
  • 破坏性变更:
  • 最低的 jaxlib 版本现在是 0.1.64。
  • 一些分析器 API 名称已更改。仍然存在别名,因此不应该破坏现有代码,但别名最终将被移除,请更改您的代码。
  • TraceContext –> TraceAnnotation()
  • StepTraceContext –> StepTraceAnnotation()
  • trace_function –> annotate_function()
  • 无法禁用全局分析。有关更多信息,请参阅 omnistaging
  • Python 整数大于最大的int64值现在在所有情况下都会导致溢出,而不是在某些情况下静默转换为uint64#6047)。
  • 在非 X64 模式下,超出int32可表示范围的 Python 整数现在将导致OverflowError,而不是静默截断其值。
  • Bug 修复:
  • host_callback现在支持参数和结果中的空数组(#6262)。
  • jax.random.randint()在超出限制范围时会剪切而不是包裹,现在可以生成指定 dtype 的整数的全部范围(#5868)。

jax 0.2.11(2021 年 3 月 23 日)

  • #6112 添加了上下文管理器:jax.enable_checksjax.check_tracer_leaksjax.debug_nansjax.debug_infsjax.log_compiles
  • #6085 添加了jnp.delete
  • Bug 修复:
  • #6136 泛化了jax.flatten_util.ravel_pytree以处理整数 dtype。
  • #6129 修复了处理像enum.IntEnums这样的一些常量的错误
  • #6145 修复了不完全贝塔函数批处理问题
  • #6014 修复了追踪过程中的 H2D 传输问题
  • #6165 在将一些大的 Python 整数转换为浮点数时避免 OverflowErrors
  • 破坏性变更:
  • jaxlib 最小版本现在是 0.1.62。

jaxlib 0.1.64(2021 年 3 月 18 日)

jaxlib 0.1.63(2021 年 3 月 17 日)

jax 0.2.10(2021 年 3 月 5 日)

  • jax.scipy.stats.chi2()现在作为具有 logpdf 和 pdf 方法的分布可用。
  • jax.scipy.stats.betabinom()现在作为具有 logpmf 和 pmf 方法的分布可用。
  • 添加了jax.experimental.jax2tf.call_tf()以从 JAX 调用 TensorFlow 函数(#5627)和README
  • 扩展了lax.pad的批处理规则以支持填充值的批处理。
  • Bug 修复:
  • jax.numpy.take()正确处理负索引(#5768
  • 破坏性变更:
  • 调整了 JAX 的提升规则,使提升更一致且不受 JIT 影响。特别是,当适当时,二进制操作现在可以产生弱类型值。更改的主要用户可见效果是某些操作的输出精度与之前不同;例如表达式 jnp.bfloat16(1) + 0.1 * jnp.arange(10) 以前返回 float64 数组,现在返回 bfloat16 数组。JAX 的类型提升行为在类型提升语义中描述。
  • jax.numpy.linspace() 现在计算整数值的地板,即向负无穷取整,而不是向 0 取整。此更改是为了与 NumPy 1.20.0 保持一致。
  • jax.numpy.i0() 不再接受复数。之前该函数计算复数参数的绝对值。此更改是为了与 NumPy 1.20.0 的语义保持一致。
  • 几个 jax.numpy 函数不再接受元组或列表作为数组参数的替代:jax.numpy.pad()jax.numpy.raveljax.numpy.repeat()jax.numpy.reshape()。通常情况下,应使用标量或数组参数调用 jax.numpy 函数。

jaxlib 0.1.62 (2021 年 3 月 9 日)

  • 新特性:
  • 在 x86-64 机器上,默认情况下构建 jaxlib wheels 需要 AVX 指令。如果要在不支持 AVX 的机器上使用 JAX,可以使用 build.py--target_cpu_features 标志从源代码构建 jaxlib。 --target_cpu_features 还替换了 --enable_march_native

jaxlib 0.1.61 (2021 年 2 月 12 日)

jaxlib 0.1.60 (2021 年 2 月 3 日)

  • 错误修复:
  • 修复了将 CPU DeviceArrays 转换为 NumPy 数组时的内存泄漏问题。在 jaxlib 发布的 0.1.58 和 0.1.59 版本中存在该内存泄漏。
  • boolint8uint8 现在被认为是安全的,可以转换为 bfloat16 NumPy 扩展类型。

jax 0.2.9 (2021 年 1 月 26 日)

  • 扩展 jax.experimental.loops 模块以支持 pytrees。改进了错误检查和错误消息。
  • 添加 jax.experimental.enable_x64()jax.experimental.disable_x64()。这些是上下文管理器,允许在会话中临时启用/禁用 X64 模式。
  • 破坏性变更:
  • jax.ops.segment_sum() 现在在性能考虑下删除超出范围的段 ID,而不是将它们包装到段 ID 空间。

jaxlib 0.1.59 (2021 年 1 月 15 日)

jax 0.2.8 (2021 年 1 月 12 日)

  • 添加 jax.closure_convert() 用于与高阶自定义导数函数一起使用。 (#5244)
  • 添加 jax.experimental.host_callback.call() 以调用主机上的自定义 Python 函数并将结果返回到设备计算中。 (#5243)
  • 错误修复:
  • jax.numpy.arccosh 现在对复数输入返回与 numpy.arccosh 相同的分支(#5156)。
  • 现在 host_callback.id_tapjax.pmap 中也可以使用。对于 id_tapid_print,现在有一个可选参数,可以请求将值从中提取的设备作为关键字参数传递给 tap 函数(#5182)。
  • 破坏性更改:
  • jax.numpy.pad 现在接受关键字参数。位置参数 constant_values 已被移除。此外,传递不受支持的关键字参数将引发错误。
  • jax.experimental.host_callback.id_tap() 的更改(#5243):
  • 删除了对 jax.experimental.host_callback.id_tap()kwargs 支持(这种支持已经被弃用几个月了)。
  • 更改了 jax.experimental.host_callback.id_print() 中元组的打印方式,使用了 而不是 ‘‘
  • jax.experimental.host_callback.id_print() 存在 JVP 的情况下,更改了打印元组的方式,现在使用了一对主元和切线。以前是分别打印主元和切线。
  • 删除了 host_callback.outfeed_receiver(这不再需要,并且几个月前已被弃用)。
  • 新功能:
  • inf 的调试添加了一个新标志,类似于 NaN 的标志(#5224)。

jax 0.2.7(2020 年 12 月 4 日)

  • 添加了 jax.device_put_replicated
  • jax.experimental.sharded_jit 添加了多主机支持。
  • 增加对 jax.numpy.linalg.eig 计算的特征值的微分支持。
  • 增加了对在 Windows 平台上构建的支持。
  • jax.pmap 中添加了对通用 in_axesout_axes 的支持。
  • 添加了对 jax.numpy.linalg.slogdet 的复数支持。
  • Bug 修复:
  • 修复 jax.numpy.sinc 在零点处高于二阶导数的问题。
  • 修复了在转置规则中的符号零的一些难以命中的 bug。
  • 破坏性更改:
  • 已删除 jax.experimental.optix,改为独立的 optax Python 包。
  • 使用非元组序列索引 JAX 数组现在会引发 TypeError。这种类型的索引自从 Numpy v1.16 和 JAX v0.2.4 开始已经被弃用。参见 #4564

jax 0.2.6(2020 年 11 月 18 日)

  • jax.experimental.jax2tf 转换器的形状多态跟踪添加了支持。参见 README.md
  • 破坏性更改清理:
  • 对于 jax.jitxla_computation 中的非可哈希静态参数,现在会引发错误。参见 cb48f42
  • 改善了类型提升行为的一致性(#4744):
  • 将复杂的 Python 标量添加到 JAX 浮点数会保留 JAX 浮点数的精度。例如,jnp.float32(1) + 1j 现在返回 complex64,而之前返回的是 complex128
  • 当涉及到包含 uint64、有符号整型和第三种类型的三个或更多术语的类型提升时,现在与参数顺序无关。例如:jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)jnp.result_type(jnp.float16, jnp.uint64, jnp.int64) 都返回 float16,之前第一个返回 float64,第二个返回 float16
  • (未记录的) jax.lax_linalg 线性代数模块现在公开为 jax.lax.linalg
  • jax.random.PRNGKey 现在在 JIT 编译内外产生相同的结果 (#4877)。这需要在几个特定情况下更改给定种子的结果:
  • 使用 jax_enable_x64=False 时,作为 Python 整数传递的负数种子现在在 JIT 模式外返回不同的结果。例如,jax.random.PRNGKey(-1) 以前返回 [4294967295, 4294967295],现在返回 [0, 4294967295]。这与 JIT 中的行为一致。
  • JIT 外部的 int64 不能表示的范围外的种子现在会导致 OverflowError 而不是 TypeError。这与 JIT 中的行为一致。
  • 要恢复在 jax_enable_x64=False 时以前针对负整数返回的键,可以使用:
key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF) 
  • 当尝试访问已删除其值的 DeviceArray 时,现在会引发 RuntimeError 而不是 ValueError

jaxlib 0.1.58 (2021 年 1 月 12 日)

  • 修复了 JAX 有时返回平台特定类型(如 np.cint)而不是标准类型(如 np.int32)的 Bug (#4903)。
  • 修复了在执行某些 int16 操作时常量折叠导致崩溃的问题 (#4971)。
  • pytree.flatten() 中添加了一个 is_leaf 谓词。

jaxlib 0.1.57 (2020 年 11 月 12 日)

  • 修复了 GPU wheels 中的 manylinux2010 兼容性问题。
  • 将 CPU FFT 实现从 Eigen 切换到 PocketFFT。
  • 修复了 bfloat16 值哈希未正确初始化并可能更改的 Bug (#4651)。
  • 添加了对将数组传递给 DLPack 时保留所有权的支持 (#4636)。
  • 修复了批量三角求解的一个 Bug,对大于 128 但不是 128 的倍数的情况。
  • 修复了在多个 GPU 上同时进行并发 FFT 时的 Bug (#3518)。
  • 在分析器中修复了工具缺失的 Bug (#4427)。
  • 放弃了对 CUDA 10.0 的支持。

jax 0.2.5 (2020 年 10 月 27 日)

jax 0.2.4 (2020 年 10 月 19 日)

  • jax.experimental.host_callback 添加了对 remat 的支持。参见 #4608
  • 弃用
  • 现在,使用非元组序列进行索引已被弃用,遵循 Numpy 中的类似弃用。在将来的版本中,这将导致 TypeError。参见 #4564

jaxlib 0.1.56 (2020 年 10 月 14 日)。

jax 0.2.3 (2020 年 10 月 14 日)。

  • GitHub 提交记录
  • 由于需要暂时回退新的 jit 快速通路,因此又进行了一个新的发布。

jax 0.2.2 (2020 年 10 月 13 日)。

jax 0.2.1 (2020 年 10 月 6 日)。

  • 作为全阶段的一个好处,即使 jax.experimental.host_callback.id_print() / jax.experimental.host_callback.id_tap() 的结果未在计算中使用,也会按程序顺序执行 host_callback 函数。

jax (0.2.0) (2020 年 9 月 23 日)。

jax (0.1.77) (2020 年 9 月 15 日)。

  • 破坏性变更:
  • jax.experimental.host_callback.id_tap() 的新简化接口 (#4101)。

jaxlib 0.1.55 (2020 年 9 月 8 日)。

  • 更新 XLA:
  • 修复 DLPackManagedTensorToBuffer 中的错误 (#4196)。

jax 0.1.76 (2020 年 9 月 8 日)。

jax 0.1.75 (2020 年 7 月 30 日)。

  • 使 jnp.abs() 适用于无符号输入 (#3914)。
  • 改进:
  • 添加了“全阶段”行为,但在默认情况下已禁用 (#3370)。

jax 0.1.74 (2020 年 7 月 29 日)。

  • BFGS (#3101)。
  • TPU 支持半精度算术 (#3878)。
  • Bug 修复:
  • 防止一些意外的 dtype 警告 (#3874)。
  • 修复自定义导数中的多线程错误 (#3845, #3869)。
  • 改进:
  • 更快的 searchsorted 实现 (#3873)。
  • 为 jax.numpy 排序算法提供更好的测试覆盖率 (#3836)。

jaxlib 0.1.52 (2020 年 7 月 22 日)。

  • 更新 XLA。

jax 0.1.73 (2020 年 7 月 22 日)。

  • jax.image.resize. (#3703)。
  • hfft 和 ihfft (#3664)。
  • jax.numpy.intersect1d (#3726)。
  • jax.numpy.lexsort (#3812)。
  • 当降低到 XLA 时,lax.scanscan 原语支持一个 unroll 参数用于循环展开 (#3738)。
  • Bug 修复:
  • 修复重复轴错误的约简 (#3618)。
  • 修复 lax.pad 对输入维度大小为 0 的形状规则错误。 (#3608)。
  • 使 psum 转置处理零余切 (#3653)。
  • 修复在尺寸为 0 的轴上进行 reduce-prod 的 JVP 的形状错误 (#3729)。
  • 支持通过 jax.lax.all_to_all 进行微分。
  • 解决了 jax.scipy.special.zeta 中的 nan 问题。(#3777)
  • 改进:
  • 对 jax2tf 进行了许多改进。
  • 重新实现了使用单次变量减少的 argmin/argmax。(#3611)
  • 默认启用 XLA SPMD 分区。(#3151)
  • 支持 0d 转置卷积。(#3643)
  • 使低秩矩阵的 LU 梯度工作。
  • 支持 jet 中的多结果和自定义 JVPs。
  • 通用化了 reduce-window 的填充,支持(lo, hi)对。(#3728)
  • 在 CPU 和 GPU 上实现复杂卷积。(#3735)
  • 使 jnp.take 在空数组的空切片上工作。(#3751)
  • 放宽了 dot_general 的维度排序规则。(#3778)
  • 启用 GPU 的缓冲捐赠。(#3800)
  • 为减少窗口操作添加了基本扩张和窗口扩张支持…(#3803)

jaxlib 0.1.51(2020 年 7 月 2 日)

  • 更新 XLA。
  • 添加了对 host_callback 的新运行时支持。

jax 0.1.72(2020 年 6 月 28 日)

  • 修复了前一个版本中引入的 odeint Bug,见 #3587

jax 0.1.71(2020 年 6 月 25 日)

  • 允许 jax.experimental.ode.odeint 动态函数在我们对其进行微分的值上进行闭包 #3562

jaxlib 0.1.50(2020 年 6 月 25 日)

  • 增加了对 CUDA 11.0 的支持。
  • 放弃对 CUDA 9.2 的支持(我们只支持最后四个 CUDA 版本)。
  • 更新 XLA。

jaxlib 0.1.49(2020 年 6 月 19 日)

  • Bug 修复:

jaxlib 0.1.48(2020 年 6 月 12 日)

  • 新特性:
  • 增加了快速回溯收集的支持。
  • 增加了对设备堆分析的初步支持。
  • bfloat16 类型实现了 np.nextafter
  • CPU 和 GPU 上的 Complex128 支持 FFT。
  • Bug 修复:
  • 改进了在 GPU 上 tanh 的 float64 精度。
  • GPU 上的 float64 散布现在更快了。
  • 在 CPU 上的复杂矩阵乘法应该更快了。
  • CPU 上的稳定排序现在实际上是稳定的了。
  • CPU 后端的并发 Bug 修复。

jax 0.1.70(2020 年 6 月 8 日)

  • lax.switch 引入了带有多分支的索引条件,并与 cond 原语的泛化一起使用 #3318

jax 0.1.69(2020 年 6 月 3 日)

jax 0.1.68(2020 年 5 月 21 日)

  • lax.cond() 支持单操作数形式,作为两个分支的参数 #2993
  • 注意事项改动:
  • jax.experimental.host_callback.id_tap() 原语的 transforms 关键字格式已更改 #3132

jax 0.1.67(2020 年 5 月 12 日)

  • 支持使用 axis_index_groups 对 pmapped 轴的子集进行缩减 #2382
  • 实验性支持从编译代码调用和打印主机端 Python 函数。参见 id_print 和 id_tap#3006)。
  • 显著变更:
  • jax.numpy 导出的名称的可见性已加强。这可能会破坏之前无意中使用这些名称的代码。

jaxlib 0.1.47(2020 年 5 月 8 日)

  • 修复 outfeed 引起的崩溃。

jax 0.1.66(2020 年 5 月 5 日)

  • 支持在 pmap() 上使用 in_axes=None 进行缩减 #2896

jaxlib 0.1.46(2020 年 5 月 5 日)

  • 修复 Mac OS X 上线性代数函数的崩溃(#432)。
  • 修复使用 AVX512 指令时因操作系统或虚拟化程序禁用而导致的非法指令崩溃问题(#2906)。

jax 0.1.65(2020 年 4 月 30 日)

  • 对奇异矩阵行列式的微分 #2809
  • Bug 修复:
  • 修复 odeint() 对于具有时间依赖动态的常微分方程的时间微分问题 #2817,并添加 ODE CI 测试。
  • 修复 lax_linalg.qr() 的微分问题 #2867

jaxlib 0.1.45(2020 年 4 月 21 日)

  • 修复段错误:#2755
  • 在 Sort HLO 上通过 Plumb 选项支持稳定性。

jax 0.1.64(2020 年 4 月 21 日)

  • 添加函数式索引更新的语法糖 #2684
  • 添加 jax.numpy.linalg.multi_dot() #2726
  • 添加 jax.numpy.unique() #2760
  • 添加 jax.numpy.rint() #2724
  • 添加 jax.numpy.rint() #2724
  • jax.experimental.jet() 添加更多原始规则。
  • Bug 修复:
  • 修复 logaddexp()logaddexp2() 在零处的微分问题 #2107
  • 在没有 jit() 的情况下改进反向模式自动微分的内存使用情况 #2719
  • 更好的错误修复:
  • 改进 lax.while_loop() 的反向模式微分的错误消息 #2129

jaxlib 0.1.44(2020 年 4 月 16 日)

  • 修复了一个 bug,即当存在多个不同型号的 GPU 时,JAX 只会编译适用于第一个 GPU 的程序。
  • 修复了batch_group_count卷积的错误。
  • 为更多 GPU 版本添加了预编译的 SASS,以避免启动时 PTX 编译挂起。

jax 0.1.63 (2020 年 4 月 12 日)

  • GitHub 提交记录
  • 添加了jax.custom_jvpjax.custom_vjp,来源于 #2026,请参阅教程笔记本。弃用了jax.custom_transforms并将其从文档中删除(尽管它仍然可用)。
  • 添加了scipy.sparse.linalg.cg #2566
  • 更改了 Tracers 的打印方式,以显示更多有用的调试信息 #2591
  • 修复了jax.numpy.isclose正确处理naninf的方式 #2501
  • 添加了几个jax.experimental.jet的新规则 #2537
  • 当未提供scale/center时,修复了jax.experimental.stax.BatchNorm
  • 修复了jax.numpy.einsum中广播的一些缺失情况 #2512
  • 通过并行前缀扫描实现了jax.numpy.cumsumjax.numpy.cumprod,并使reduce_prod对任意阶数可微分 #2596 #2597
  • conv_general_dilated中添加了batch_group_count #2635
  • test_util.check_grads添加了文档字符串 #2656
  • 添加了callback_transform #2665
  • 实现了rollaxisconvolve/correlate的 1 维和 2 维、copysigntruncroots以及quantile/percentile的插值选项。

jaxlib 0.1.43 (2020 年 3 月 31 日)

  • 修复了 GPU 上 Resnet-50 的性能回归问题。

jax 0.1.62 (2020 年 3 月 21 日)

  • GitHub 提交记录
  • JAX 已停止支持 Python 3.5。请升级到 Python 3.6 或更新版本。
  • 删除了内部函数lax._safe_mul,该函数实现了约定0. * nan == 0.。此更改意味着在某些程序被微分时会产生 nan,而不是以前产生正确值,尽管这确保了对其他程序产生 nan 而不是静默的不正确结果。详见 #2447 和 #1052。
  • 添加了一个all_gather并行便利函数。
  • 在核心代码中增加了更多类型注解。

jaxlib 0.1.42 (2020 年 3 月 19 日)

  • jaxlib 0.1.41 由于 API 不兼容性破坏了云 TPU 支持。此版本修复了这个问题。
  • JAX 已停止支持 Python 3.5。请升级到 Python 3.6 或更新版本。

jax 0.1.61 (2020 年 3 月 17 日)

  • GitHub 提交记录
  • 修复 Python 3.5 支持。这将是 JAX 或 jaxlib 版本的最后一个支持 Python 3.5 的版本。

jax 0.1.60(2020 年 3 月 17 日)

  • jax.pmap() 增加了 static_broadcast_argnums 参数,该参数允许用户指定应该作为编译时常数处理的参数,并应广播到所有设备。它类似于 jax.jit() 中的 static_argnums
  • 改善了错误消息,以防止错误地在全局状态中保存跟踪器。
  • 添加了 jax.nn.one_hot() 实用函数。
  • 添加了 jax.experimental.jet,用于更快的高阶自动微分。
  • jax.lax.broadcast_in_dim() 的参数进行了更多正确性检查。
  • 最小 jaxlib 版本现已是 0.1.41。

jaxlib 0.1.40(2020 年 3 月 4 日)

  • 添加了 Jaxlib 对 TensorFlow 分析仪的实验性支持,该分析仪允许从 TensorBoard 跟踪 CPU 和 GPU 计算。
  • 包括多主机 GPU 计算支持的原型,该计算通过 NCCL 通信。
  • 改善了在 GPU 上的 NCCL 集合性能。
  • 添加了 TopK、CustomCallWithoutLayout、CustomCallWithLayout、IGammaGradA 和 RandomGamma 实现。
  • 支持在 XLA 编译时已知的设备分配。

jax 0.1.59(2020 年 2 月 11 日)

  • 最小 jaxlib 版本现已是 0.1.38。
  • 简化 Jaxpr,通过删除 Jaxpr.freevarsJaxpr.bound_subjaxprs。调用基本功能(xla_callxla_pmapsharded_callremat_call)获取一个新的参数 call_jaxpr,它具有一个完全闭合(无 constvars)的 jaxpr。此外,还添加了一个新的字段 call_primitive 到基本功能。
  • 新功能:
  • 反向模式自动微分(例如 grad)对 lax.cond 的支持,使其在两种模式下都可微分(#2091
  • JAX 现在支持 DLPack,它允许以零副本方式共享 CPU 和 GPU 数组与其他库(例如 PyTorch)。
  • JAX GPU DeviceArrays 现在支持 __cuda_array_interface__,这是另一种用于与 CuPy 和 Numba 等库共享 GPU 数组的零副本协议。
  • JAX 的 CPU 设备缓冲区现在实现了 Python 缓冲区协议,这允许 JAX 和 NumPy 之间的零副本缓冲区共享。
  • 添加了名为 JAX_SKIP_SLOW_TESTS 的环境变量,以跳过已知为慢的测试。

jaxlib 0.1.39(2020 年 2 月 11 日)

  • 更新 XLA。

jaxlib 0.1.38(2020 年 1 月 29 日)

  • 不再支持 CUDA 9.0。
  • 默认构建 CUDA 10.2 的轮。

jax 0.1.58(2020 年 1 月 28 日)

  • JAX 已弃用对 Python 2 的支持,因为 Python 2 于 2020 年 1 月 1 日达到生命周期结束。请更新到 Python 3.5 或更新版本。
  • 新功能
  • 正向模式自动微分(jvp)对 while 循环的支持(#1980

  • 新的 NumPy 和 SciPy 功能:
jax.numpy.fft.fft2()
jax.numpy.fft.ifft2()
jax.numpy.fft.rfft()
jax.numpy.fft.irfft()
jax.numpy.fft.rfft2()
jax.numpy.fft.irfft2()
jax.numpy.fft.rfftn()
jax.numpy.fft.irfftn()
jax.numpy.fft.fftfreq()
jax.numpy.fft.rfftfreq()
jax.numpy.linalg.matrix_rank()
jax.numpy.linalg.matrix_power()
jax.scipy.special.betainc()


  • 现在在 GPU 上进行批次 Cholesky 分解时使用了更高效的批次核心。

显著的错误修复

  • 使用 Python 3 升级后,JAX 不再依赖于fastcache,这应该有助于安装。
相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
3月前
|
并行计算 异构计算 索引
JAX 中文文档(十六)(4)
JAX 中文文档(十六)
28 2
|
3月前
|
存储 API 索引
JAX 中文文档(十五)(5)
JAX 中文文档(十五)
34 3
|
3月前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(4)
JAX 中文文档(十五)
26 3
|
3月前
|
并行计算 API 异构计算
JAX 中文文档(十六)(2)
JAX 中文文档(十六)
77 1
|
3月前
|
存储 缓存 API
JAX 中文文档(十六)(1)
JAX 中文文档(十六)
30 1
|
3月前
|
机器学习/深度学习 数据可视化 编译器
JAX 中文文档(十四)(5)
JAX 中文文档(十四)
34 2
|
3月前
|
算法 API 开发工具
JAX 中文文档(十二)(5)
JAX 中文文档(十二)
40 1
|
3月前
|
并行计算 API 异构计算
JAX 中文文档(十六)(3)
JAX 中文文档(十六)
55 0
|
3月前
|
安全 API 网络架构
JAX 中文文档(十五)(1)
JAX 中文文档(十五)
37 0
|
3月前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(2)
JAX 中文文档(十五)
23 0