JAX 中文文档(一)(1)

简介: JAX 中文文档(一)


原文:jax.readthedocs.io/en/latest/

开始入门

安装 JAX

原文:jax.readthedocs.io/en/latest/installation.html

使用 JAX 需要安装两个包:jax 是纯 Python 的跨平台库,jaxlib 包含编译的二进制文件,对于不同的操作系统和加速器需要不同的构建。

TL;DR 对于大多数用户来说,典型的 JAX 安装可能如下所示:

  • 仅限 CPU(Linux/macOS/Windows)
pip install -U jax 
  • GPU(NVIDIA,CUDA 12)
pip install -U "jax[cuda12]" 
  • TPU(Google Cloud TPU VM)
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 

支持的平台

下表显示了所有支持的平台和安装选项。检查您的设置是否受支持;如果显示“是”或“实验性”,请单击相应链接以了解更详细的 JAX 安装方法。

Linux,x86_64 Linux,aarch64 macOS,Intel x86_64,AMD GPU macOS,Apple Silicon,基于 ARM Windows,x86_64 Windows WSL2,x86_64
CPU
NVIDIA GPU 不适用 实验性
Google Cloud TPU 不适用 不适用 不适用 不适用 不适用
AMD GPU 实验性 不适用

| Apple GPU | 不适用 | 否 | 实验性 | 实验性 | 不适用 | 不适用 | ## CPU

pip 安装:CPU

目前,JAX 团队为以下操作系统和架构发布 jaxlib 轮子:

  • Linux,x86_64
  • Linux, aarch64
  • macOS,Intel
  • macOS,基于 Apple ARM
  • Windows,x86_64(实验性

要安装仅 CPU 版本的 JAX,可能对于在笔记本电脑上进行本地开发非常有用,您可以运行:

pip  install  --upgrade  pip
pip  install  --upgrade  jax 

在 Windows 上,如果尚未安装 Microsoft Visual Studio 2019 Redistributable,您可能还需要安装它。

其他操作系统和架构需要从源代码构建。在其他操作系统和架构上尝试 pip 安装可能导致 jaxlib 未能与 jax 一起安装(虽然 jax 可能成功安装,但在运行时可能会失败)。 ## NVIDIA GPU

JAX 支持具有 SM 版本 5.2(Maxwell)或更新版本的 NVIDIA GPU。请注意,由于 NVIDIA 在其软件中停止了对 Kepler 系列 GPU 的支持,JAX 不再支持 Kepler 系列 GPU。

您必须先安装 NVIDIA 驱动程序。建议您安装 NVIDIA 提供的最新驱动程序,但驱动版本必须 >= 525.60.13 才能在 Linux 上运行 CUDA 12。

如果您需要在较老的驱动程序上使用更新的 CUDA 工具包,例如在无法轻松更新 NVIDIA 驱动程序的集群上,您可以使用 NVIDIA 专门为此目的提供的 CUDA 向前兼容包

pip 安装:NVIDIA GPU(通过 pip 安装,更加简便)

有两种安装 JAX 并支持 NVIDIA GPU 的方式:

  • 使用从 pip 轮子安装的 NVIDIA CUDA 和 cuDNN
  • 使用自行安装的 CUDA/cuDNN

JAX 团队强烈建议使用 pip wheel 安装 CUDA 和 cuDNN,因为这样更加简单!

NVIDIA 仅为 x86_64 和 aarch64 平台发布了 CUDA pip 包;在其他平台上,您必须使用本地安装的 CUDA。

pip  install  --upgrade  pip
# NVIDIA CUDA 12 installation
# Note: wheels only available on linux.
pip  install  --upgrade  "jax[cuda12]" 

如果 JAX 检测到错误版本的 NVIDIA CUDA 库,您需要检查以下几点:

  • 请确保未设置 LD_LIBRARY_PATH,因为 LD_LIBRARY_PATH 可能会覆盖 NVIDIA CUDA 库。
  • 确保安装的 NVIDIA CUDA 库与 JAX 请求的库相符。重新运行上述安装命令应该可以解决问题。

pip 安装:NVIDIA GPU(本地安装的 CUDA,更为复杂)

如果您想使用预安装的 NVIDIA CUDA 副本,您必须首先安装 NVIDIA 的 CUDA cuDNN

JAX 仅为 Linux x86_64 和 Linux aarch64 提供预编译的 CUDA 兼容 wheel。其他操作系统和架构的组合也可能存在,但需要从源代码构建(请参考构建指南以了解更多信息)。

您应该使用至少与您的NVIDIA CUDA toolkit 对应的驱动版本相同的 NVIDIA 驱动程序版本。例如,在无法轻易更新 NVIDIA 驱动程序的集群上需要使用更新的 CUDA 工具包,您可以使用 NVIDIA 为此目的提供的CUDA 向前兼容包

JAX 目前提供一种 CUDA wheel 变体:

Built with Compatible with
CUDA 12.3 CUDA >=12.1
CUDNN 9.0 CUDNN >=9.0, <10.0
NCCL 2.19 NCCL >=2.18

JAX 检查您的库的版本,如果版本不够新,则会报错。设置 JAX_SKIP_CUDA_CONSTRAINTS_CHECK 环境变量将禁用此检查,但使用较旧版本的 CUDA 可能会导致错误或不正确的结果。

NCCL 是一个可选依赖项,仅在执行多 GPU 计算时才需要。

安装方法如下:

pip  install  --upgrade  pip
# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.0 or newer.
# Note: wheels only available on linux.
pip  install  --upgrade  "jax[cuda12_local]" 

这些 pip 安装在 Windows 上无法工作,并可能静默失败;请参考上表。

您可以使用以下命令查找您的 CUDA 版本:

nvcc  --version 

JAX 使用 LD_LIBRARY_PATH 查找 CUDA 库,并使用 PATH 查找二进制文件(ptxasnvlink)。请确保这些路径指向正确的 CUDA 安装位置。

如果在使用预编译的 wheel 时遇到任何错误或问题,请在GitHub 问题跟踪器上告知 JAX 团队。

NVIDIA GPU Docker 容器

NVIDIA 提供了JAX 工具箱容器,这些是 bleeding edge 容器,包含 jax 的夜间版本和一些模型/框架。 ## Google Cloud TPU

pip 安装:Google Cloud TPU

JAX 为 Google Cloud TPU 提供预构建的安装包。要在云 TPU VM 中安装 JAX 及相应版本的 jaxliblibtpu,您可以运行以下命令:

pip  install  jax[tpu]  -f  https://storage.googleapis.com/jax-releases/libtpu_releases.html 

对于 Colab 的用户(https://colab.research.google.com/),请确保您使用的是 TPU v2 而不是已过时的旧 TPU 运行时。## Apple Silicon GPU(基于 ARM 的)

pip 安装:Apple 基于 ARM 的 Silicon GPU

Apple 为基于 ARM 的 GPU 硬件提供了一个实验性的 Metal 插件。详情请参阅 Apple 的 JAX on Metal 文档

注意: Metal 插件存在一些注意事项:

  • Metal 插件是新的实验性质,并存在一些已知问题,请在 JAX 问题跟踪器上报告任何问题。
  • 当前的 Metal 插件需要非常特定版本的 jaxjaxlib。随着插件 API 的成熟,此限制将逐步放宽。## AMD GPU

JAX 具有实验性的 ROCm 支持。有两种安装 JAX 的方法:

  • 使用 AMD 的 Docker 容器;或者
  • 从源代码构建(参见从源代码构建 —— 一个名为 Additional notes for building a ROCM jaxlib for AMD GPUs 的部分)。

Conda(社区支持)

Conda 安装

存在一个社区支持的 jax 的 Conda 构建。要使用 conda 安装它,只需运行:

conda  install  jax  -c  conda-forge 

要在带有 NVIDIA GPU 的机器上安装它,请运行:

conda  install  jaxlib=*=*cuda*  jax  cuda-nvcc  -c  conda-forge  -c  nvidia 

请注意,由 conda-forge 分发的 cudatoolkit 缺少 JAX 所需的 ptxas。因此,您必须从 nvidia 渠道安装 cuda-nvcc 包,或者在您的机器上单独安装 CUDA,以便 ptxas 在您的路径中可用。上述渠道顺序很重要(conda-forgenvidia 之前)。

如果您希望覆盖 JAX 使用的 CUDA 版本,或者在没有 GPU 的机器上安装 CUDA 版本,请按照 conda-forge 网站上“技巧和技巧”部分的说明操作。

前往 conda-forgejaxlibjax 存储库获取更多详细信息。

JAX 夜间安装

夜间版本反映了它们构建时主 JAX 存储库的状态,并且可能无法通过完整的测试套件。

  • 仅限 CPU:
pip  install  -U  --pre  jax  -f  https://storage.googleapis.com/jax-releases/jax_nightly_releases.html 
  • Google Cloud TPU:
pip  install  -U  --pre  jax[tpu]  -f  https://storage.googleapis.com/jax-releases/jax_nightly_releases.html  -f  https://storage.googleapis.com/jax-releases/libtpu_releases.html 
  • NVIDIA GPU(CUDA 12):
pip  install  -U  --pre  jax[cuda12]  -f  https://storage.googleapis.com/jax-releases/jax_nightly_releases.html 
  • NVIDIA GPU(CUDA 12)遗留:

用于历史 nightly 版本的单片 CUDA jaxlibs。您很可能不需要此选项;不会再构建更多的单片 CUDA jaxlibs,并且现有的将在 2024 年 9 月到期。请使用上面的“CUDA 12”选项。

pip  install  -U  --pre  jaxlib  -f  https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html 

从源代码构建 JAX

参考从源代码构建。

安装旧版本的 jaxlib wheels

由于 Python 软件包索引上的存储限制,JAX 团队定期从 http://pypi.org/project/jax 的发布中删除旧的jaxlib安装包。但是您仍然可以通过这里的 URL 直接安装它们。例如:

# Install jaxlib on CPU via the wheel archive
pip  install  jax[cpu]==0.3.25  -f  https://storage.googleapis.com/jax-releases/jax_releases.html
# Install the jaxlib 0.3.25 CPU wheel directly
pip  install  jaxlib==0.3.25  -f  https://storage.googleapis.com/jax-releases/jax_releases.html 

对于特定的旧 GPU 安装包,请确保使用jax_cuda_releases.html的 URL;例如

pip  install  jaxlib==0.3.25+cuda11.cudnn82  -f  https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 

快速入门

原文:jax.readthedocs.io/en/latest/quickstart.html

JAX 是一个面向数组的数值计算库(à la NumPy),具有自动微分和 JIT 编译功能,以支持高性能的机器学习研究

本文档提供了 JAX 主要功能的快速概述,让您可以快速开始使用 JAX:

  • JAX 提供了一个统一的类似于 NumPy 的接口,用于在 CPU、GPU 或 TPU 上运行的计算,在本地或分布式设置中。
  • JAX 通过 Open XLA 内置了即时编译(JIT)功能,这是一个开源的机器学习编译器生态系统。
  • JAX 函数支持通过其自动微分转换有效地评估梯度。
  • JAX 函数可以自动向量化,以有效地将它们映射到表示输入批次的数组上。

安装

可以直接从 Python Package Index 安装 JAX 用于 Linux、Windows 和 macOS 上的 CPU:

pip install jax 

或者,对于 NVIDIA GPU:

pip install -U "jax[cuda12]" 

如需更详细的特定平台安装信息,请查看安装 JAX。

JAX 就像 NumPy 一样

大多数 JAX 的使用是通过熟悉的 jax.numpy API 进行的,通常在 jnp 别名下导入:

import jax.numpy as jnp 

通过这个导入,您可以立即像使用典型的 NumPy 程序一样使用 JAX,包括使用 NumPy 风格的数组创建函数、Python 函数和操作符,以及数组属性和方法:

def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = jnp.arange(5.0)
print(selu(x)) 
[0\.        1.05      2.1       3.1499999 4.2      ] 

一旦您开始深入研究,您会发现 JAX 数组和 NumPy 数组之间存在一些差异;这些差异在 🔪 JAX - The Sharp Bits 🔪 中进行了探讨。

使用jax.jit()进行即时编译

JAX 可以在 GPU 或 TPU 上透明运行(如果没有,则退回到 CPU)。然而,在上述示例中,JAX 是一次将核心分派到芯片上的操作。如果我们有一系列操作,我们可以使用 jax.jit() 函数将这些操作一起编译为 XLA。

我们可以使用 IPython 的 %timeit 快速测试我们的 selu 函数,使用 block_until_ready() 来考虑 JAX 的动态分派(请参阅异步分派):

from jax import random
key = random.key(1701)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready() 
2.84 ms ± 9.23 μs per loop (mean ± std. dev. of 7 runs, 100 loops each) 

(请注意,我们已经使用 jax.random 生成了一些随机数;有关如何在 JAX 中生成随机数的详细信息,请查看伪随机数)。

我们可以使用 jax.jit() 转换来加速此函数的执行,该转换将在首次调用 selu 时进行 JIT 编译,并在此后进行缓存。

from jax import jit
selu_jit = jit(selu)
_ = selu_jit(x)  # compiles on first call
%timeit selu_jit(x).block_until_ready() 
844 μs ± 2.73 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) 

上述时间表示在 CPU 上执行,但同样的代码可以在 GPU 或 TPU 上运行,通常会有更大的加速效果。

欲了解更多关于 JAX 中 JIT 编译的信息,请查看即时编译。

使用 jax.grad() 计算导数

除了通过 JIT 编译转换函数外,JAX 还提供其他转换功能。其中一种转换是 jax.grad(),它执行自动微分 (autodiff)

from jax import grad
def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small)) 
[0.25       0.19661197 0.10499357] 

让我们用有限差分来验证我们的结果是否正确。

def first_finite_differences(f, x, eps=1E-3):
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])
print(first_finite_differences(sum_logistic, x_small)) 
[0.24998187 0.1965761  0.10502338] 

grad()jit() 转换可以任意组合并混合使用。在上面的示例中,我们对 sum_logistic 进行了 JIT 编译,然后取了它的导数。我们可以进一步进行:

print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0)) 
-0.0353256 

除了标量值函数外,jax.jacobian() 转换还可用于计算向量值函数的完整雅可比矩阵:

from jax import jacobian
print(jacobian(jnp.exp)(x_small)) 
[[1\.        0\.        0\.       ]
 [0\.        2.7182817 0\.       ]
 [0\.        0\.        7.389056 ]] 

对于更高级的自动微分操作,您可以使用 jax.vjp() 来进行反向模式向量-雅可比积分,以及使用 jax.jvp()jax.linearize() 进行正向模式雅可比-向量积分。这两者可以任意组合,也可以与其他 JAX 转换组合使用。例如,jax.jvp()jax.vjp() 用于定义正向模式 jax.jacfwd() 和反向模式 jax.jacrev(),用于计算正向和反向模式下的雅可比矩阵。以下是组合它们以有效计算完整 Hessian 矩阵的一种方法:

from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
print(hessian(sum_logistic)(x_small)) 
[[-0\.         -0\.         -0\.        ]
 [-0\.         -0.09085776 -0\.        ]
 [-0\.         -0\.         -0.07996249]] 

这种组合在实践中产生了高效的代码;这基本上是 JAX 内置的 jax.hessian() 函数的实现方式。

想了解更多关于 JAX 中的自动微分,请查看自动微分。

使用 jax.vmap() 进行自动向量化

另一个有用的转换是 vmap(),即向量化映射。它具有沿数组轴映射函数的熟悉语义,但与显式循环函数调用不同,它将函数转换为本地向量化版本,以获得更好的性能。与 jit() 组合时,它可以与手动重写函数以处理额外批处理维度的性能相媲美。

我们将处理一个简单的示例,并使用 vmap() 将矩阵-向量乘法提升为矩阵-矩阵乘法。虽然在这种特定情况下手动完成这一点很容易,但相同的技术也适用于更复杂的函数。

key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))
def apply_matrix(x):
  return jnp.dot(mat, x) 

apply_matrix 函数将一个向量映射到另一个向量,但我们可能希望将其逐行应用于矩阵。在 Python 中,我们可以通过循环遍历批处理维度来实现这一点,但通常导致性能不佳。

def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready() 
Naively batched
962 μs ± 1.54 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) 

熟悉 jnp.dot 函数的程序员可能会意识到,可以重写 apply_matrix 来避免显式循环,利用 jnp.dot 的内置批处理语义:

import numpy as np
@jit
def batched_apply_matrix(batched_x):
  return jnp.dot(batched_x, mat.T)
np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready() 
Manually batched
14.3 μs ± 28.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each) 

然而,随着函数变得更加复杂,这种手动批处理变得更加困难且容易出错。vmap() 转换旨在自动将函数转换为支持批处理的版本:

from jax import vmap
@jit
def vmap_batched_apply_matrix(batched_x):
  return vmap(apply_matrix)(batched_x)
np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready() 
Auto-vectorized with vmap
21.7 μs ± 98.7 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each) 

正如您所预期的那样,vmap() 可以与 jit()grad() 和任何其他 JAX 转换任意组合。

想了解更多关于 JAX 中的自动向量化,请查看自动向量化。

这只是 JAX 能做的一小部分。我们非常期待看到你用它做些什么!


JAX 中文文档(一)(2)https://developer.aliyun.com/article/1559830

相关实践学习
基于阿里云DeepGPU实例,用AI画唯美国风少女
本实验基于阿里云DeepGPU实例,使用aiacctorch加速stable-diffusion-webui,用AI画唯美国风少女,可提升性能至高至原性能的2.6倍。
相关文章
|
3天前
|
并行计算 API C++
JAX 中文文档(九)(4)
JAX 中文文档(九)
11 1
|
3天前
|
并行计算 API 异构计算
JAX 中文文档(六)(2)
JAX 中文文档(六)
10 1
|
3天前
|
API 索引 Python
JAX 中文文档(三)(4)
JAX 中文文档(三)
6 0
|
3天前
|
机器学习/深度学习 存储 并行计算
JAX 中文文档(七)(3)
JAX 中文文档(七)
9 0
|
3天前
|
存储 机器学习/深度学习 并行计算
JAX 中文文档(二)(5)
JAX 中文文档(二)
8 0
|
3天前
|
C++ 索引 Python
JAX 中文文档(九)(2)
JAX 中文文档(九)
7 0
|
3天前
|
Serverless C++ Python
JAX 中文文档(九)(5)
JAX 中文文档(九)
10 0
|
3天前
|
机器学习/深度学习 异构计算 AI芯片
JAX 中文文档(七)(4)
JAX 中文文档(七)
7 0
|
3天前
|
存储 安全 API
JAX 中文文档(十)(2)
JAX 中文文档(十)
9 0
|
2天前
|
编译器 API 异构计算
JAX 中文文档(一)(2)
JAX 中文文档(一)
8 0