JAX 中的广义卷积
[外链图片转存中…(img-VNybejd1-1718950765073)]
JAX 提供了多种接口来跨数据计算卷积,包括:
jax.numpy.convolve()
(也有jax.numpy.correlate()
)jax.scipy.signal.convolve()
(也有correlate()
)jax.scipy.signal.convolve2d()
(也有correlate2d()
)jax.lax.conv_general_dilated()
对于基本的卷积操作,jax.numpy
和 jax.scipy
的操作通常足够使用。如果要进行更一般的批量多维卷积,jax.lax
函数是你应该开始的地方。
基本的一维卷积
基本的一维卷积由jax.numpy.convolve()
实现,它为numpy.convolve()
提供了一个 JAX 接口。这里是通过卷积实现的简单一维平滑的例子:
import matplotlib.pyplot as plt from jax import random import jax.numpy as jnp import numpy as np key = random.key(1701) x = jnp.linspace(0, 10, 500) y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,)) window = jnp.ones(10) / 10 y_smooth = jnp.convolve(y, window, mode='same') plt.plot(x, y, 'lightgray') plt.plot(x, y_smooth, 'black');
mode
参数控制如何处理边界条件;这里我们使用mode='same'
确保输出与输入大小相同。
欲了解更多信息,请参阅jax.numpy.convolve()
文档,或与原始numpy.convolve()
函数相关的文档。
基本的N维卷积
对于N维卷积,jax.scipy.signal.convolve()
提供了类似于jax.numpy.convolve()
的界面,推广到N维。
例如,这里是一种使用高斯滤波器进行图像去噪的简单方法:
from scipy import misc import jax.scipy as jsp fig, ax = plt.subplots(1, 3, figsize=(12, 5)) # Load a sample image; compute mean() to convert from RGB to grayscale. image = jnp.array(misc.face().mean(-1)) ax[0].imshow(image, cmap='binary_r') ax[0].set_title('original') # Create a noisy version by adding random Gaussian noise key = random.key(1701) noisy_image = image + 50 * random.normal(key, image.shape) ax[1].imshow(noisy_image, cmap='binary_r') ax[1].set_title('noisy') # Smooth the noisy image with a 2D Gaussian smoothing kernel. x = jnp.linspace(-3, 3, 7) window = jsp.stats.norm.pdf(x) * jsp.stats.norm.pdf(x[:, None]) smooth_image = jsp.signal.convolve(noisy_image, window, mode='same') ax[2].imshow(smooth_image, cmap='binary_r') ax[2].set_title('smoothed');
/tmp/ipykernel_1464/4118182506.py:7: DeprecationWarning: scipy.misc.face has been deprecated in SciPy v1.10.0; and will be completely removed in SciPy v1.12.0\. Dataset methods have moved into the scipy.datasets module. Use scipy.datasets.face instead. image = jnp.array(misc.face().mean(-1))
如同一维情况,我们使用mode='same'
指定如何处理边缘。有关N维卷积中可用选项的更多信息,请参阅jax.scipy.signal.convolve()
文档。
广义卷积
对于在构建深度神经网络中通常有用的更一般类型的批量卷积,JAX 和 XLA 提供了非常通用的 N 维conv_general_dilated函数,但如何使用它并不是很明显。我们将给出一些常见用例的示例。
一篇关于卷积算术的家族调查,卷积算术指南,强烈推荐阅读!
让我们定义一个简单的对角边缘核:
# 2D kernel - HWIO layout kernel = jnp.zeros((3, 3, 3, 3), dtype=jnp.float32) kernel += jnp.array([[1, 1, 0], [1, 0,-1], [0,-1,-1]])[:, :, jnp.newaxis, jnp.newaxis] print("Edge Conv kernel:") plt.imshow(kernel[:, :, 0, 0]);
Edge Conv kernel:
接下来我们将创建一个简单的合成图像:
# NHWC layout img = jnp.zeros((1, 200, 198, 3), dtype=jnp.float32) for k in range(3): x = 30 + 60*k y = 20 + 60*k img = img.at[0, x:x+10, y:y+10, k].set(1.0) print("Original Image:") plt.imshow(img[0]);
Original Image:
lax.conv
和 lax.conv_with_general_padding
这些是卷积的简单便捷函数
️⚠️ 便捷函数 lax.conv
,lax.conv_with_general_padding
假定 NCHW 图像和 OIHW 卷积核。
from jax import lax out = lax.conv(jnp.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor jnp.transpose(kernel,[3,2,0,1]), # rhs = OIHW conv kernel tensor (1, 1), # window strides 'SAME') # padding mode print("out shape: ", out.shape) print("First output channel:") plt.figure(figsize=(10,10)) plt.imshow(np.array(out)[0,0,:,:]);
out shape: (1, 3, 200, 198) First output channel:
out = lax.conv_with_general_padding( jnp.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor jnp.transpose(kernel,[2,3,0,1]), # rhs = IOHW conv kernel tensor (1, 1), # window strides ((2,2),(2,2)), # general padding 2x2 (1,1), # lhs/image dilation (1,1)) # rhs/kernel dilation print("out shape: ", out.shape) print("First output channel:") plt.figure(figsize=(10,10)) plt.imshow(np.array(out)[0,0,:,:]);
out shape: (1, 3, 202, 200) First output channel:
维度编号定义了 conv_general_dilated
的维度布局
重要的参数是轴布局的三元组:(输入布局,卷积核布局,输出布局)
- N - 批次维度
- H - 空间高度
- W - 空间宽度
- C - 通道维度
- I - 卷积核 输入 通道维度
- O - 卷积核 输出 通道维度
⚠️ 为了展示维度编号的灵活性,我们选择了 NHWC 图像和 HWIO 卷积核约定,如下所示 lax.conv_general_dilated
。
dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape kernel.shape, # only ndim matters, not shape ('NHWC', 'HWIO', 'NHWC')) # the important bit print(dn)
ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))
SAME 填充,无步长,无扩张
out = lax.conv_general_dilated(img, # lhs = image tensor kernel, # rhs = conv kernel tensor (1,1), # window strides 'SAME', # padding mode (1,1), # lhs/image dilation (1,1), # rhs/kernel dilation dn) # dimension_numbers = lhs, rhs, out dimension permutation print("out shape: ", out.shape) print("First output channel:") plt.figure(figsize=(10,10)) plt.imshow(np.array(out)[0,:,:,0]);
out shape: (1, 200, 198, 3) First output channel:
VALID 填充,无步长,无扩张
out = lax.conv_general_dilated(img, # lhs = image tensor kernel, # rhs = conv kernel tensor (1,1), # window strides 'VALID', # padding mode (1,1), # lhs/image dilation (1,1), # rhs/kernel dilation dn) # dimension_numbers = lhs, rhs, out dimension permutation print("out shape: ", out.shape, "DIFFERENT from above!") print("First output channel:") plt.figure(figsize=(10,10)) plt.imshow(np.array(out)[0,:,:,0]);
out shape: (1, 198, 196, 3) DIFFERENT from above! First output channel:
SAME 填充,2,2 步长,无扩张
out = lax.conv_general_dilated(img, # lhs = image tensor kernel, # rhs = conv kernel tensor (2,2), # window strides 'SAME', # padding mode (1,1), # lhs/image dilation (1,1), # rhs/kernel dilation dn) # dimension_numbers = lhs, rhs, out dimension permutation print("out shape: ", out.shape, " <-- half the size of above") plt.figure(figsize=(10,10)) print("First output channel:") plt.imshow(np.array(out)[0,:,:,0]);
out shape: (1, 100, 99, 3) <-- half the size of above First output channel:
VALID 填充,无步长,rhs 卷积核扩张 ~ 膨胀卷积(用于演示)
out = lax.conv_general_dilated(img, # lhs = image tensor kernel, # rhs = conv kernel tensor (1,1), # window strides 'VALID', # padding mode (1,1), # lhs/image dilation (12,12), # rhs/kernel dilation dn) # dimension_numbers = lhs, rhs, out dimension permutation print("out shape: ", out.shape) plt.figure(figsize=(10,10)) print("First output channel:") plt.imshow(np.array(out)[0,:,:,0]);
out shape: (1, 176, 174, 3) First output channel:
VALID 填充,无步长,lhs=input 扩张 ~ 转置卷积
out = lax.conv_general_dilated(img, # lhs = image tensor kernel, # rhs = conv kernel tensor (1,1), # window strides ((0, 0), (0, 0)), # padding mode (2,2), # lhs/image dilation (1,1), # rhs/kernel dilation dn) # dimension_numbers = lhs, rhs, out dimension permutation print("out shape: ", out.shape, "<-- larger than original!") plt.figure(figsize=(10,10)) print("First output channel:") plt.imshow(np.array(out)[0,:,:,0]);
out shape: (1, 397, 393, 3) <-- larger than original! First output channel:
我们可以用最后一个示例,比如实现 转置卷积:
# The following is equivalent to tensorflow: # N,H,W,C = img.shape # out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1)) # transposed conv = 180deg kernel rotation plus LHS dilation # rotate kernel 180deg: kernel_rot = jnp.rot90(jnp.rot90(kernel, axes=(0,1)), axes=(0,1)) # need a custom output padding: padding = ((2, 1), (2, 1)) out = lax.conv_general_dilated(img, # lhs = image tensor kernel_rot, # rhs = conv kernel tensor (1,1), # window strides padding, # padding mode (2,2), # lhs/image dilation (1,1), # rhs/kernel dilation dn) # dimension_numbers = lhs, rhs, out dimension permutation print("out shape: ", out.shape, "<-- transposed_conv") plt.figure(figsize=(10,10)) print("First output channel:") plt.imshow(np.array(out)[0,:,:,0]);
out shape: (1, 400, 396, 3) <-- transposed_conv First output channel:
1D 卷积
你不仅限于 2D 卷积,下面是一个简单的 1D 演示:
# 1D kernel - WIO layout kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]], [[1, 1, 1], [-1, -1, -1]]], dtype=jnp.float32).transpose([2,1,0]) # 1D data - NWC layout data = np.zeros((1, 200, 2), dtype=jnp.float32) for i in range(2): for k in range(2): x = 35*i + 30 + 60*k data[0, x:x+30, k] = 1.0 print("in shapes:", data.shape, kernel.shape) plt.figure(figsize=(10,5)) plt.plot(data[0]); dn = lax.conv_dimension_numbers(data.shape, kernel.shape, ('NWC', 'WIO', 'NWC')) print(dn) out = lax.conv_general_dilated(data, # lhs = image tensor kernel, # rhs = conv kernel tensor (1,), # window strides 'SAME', # padding mode (1,), # lhs/image dilation (1,), # rhs/kernel dilation dn) # dimension_numbers = lhs, rhs, out dimension permutation print("out shape: ", out.shape) plt.figure(figsize=(10,5)) plt.plot(out[0]);
in shapes: (1, 200, 2) (3, 2, 2) ConvDimensionNumbers(lhs_spec=(0, 2, 1), rhs_spec=(2, 1, 0), out_spec=(0, 2, 1)) out shape: (1, 200, 2)
3D 卷积
import matplotlib as mpl # Random 3D kernel - HWDIO layout kernel = jnp.array([ [[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis] # 3D data - NHWDC layout data = jnp.zeros((1, 30, 30, 30, 1), dtype=jnp.float32) x, y, z = np.mgrid[0:1:30j, 0:1:30j, 0:1:30j] data += (jnp.sin(2*x*jnp.pi)*jnp.cos(2*y*jnp.pi)*jnp.cos(2*z*jnp.pi))[None,:,:,:,None] print("in shapes:", data.shape, kernel.shape) dn = lax.conv_dimension_numbers(data.shape, kernel.shape, ('NHWDC', 'HWDIO', 'NHWDC')) print(dn) out = lax.conv_general_dilated(data, # lhs = image tensor kernel, # rhs = conv kernel tensor (1,1,1), # window strides 'SAME', # padding mode (1,1,1), # lhs/image dilation (1,1,1), # rhs/kernel dilation dn) # dimension_numbers print("out shape: ", out.shape) # Make some simple 3d density plots: from mpl_toolkits.mplot3d import Axes3D def make_alpha(cmap): my_cmap = cmap(jnp.arange(cmap.N)) my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3 return mpl.colors.ListedColormap(my_cmap) my_cmap = make_alpha(plt.cm.viridis) fig = plt.figure() ax = fig.add_subplot(projection='3d') ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=data.ravel(), cmap=my_cmap) ax.axis('off') ax.set_title('input') fig = plt.figure() ax = fig.add_subplot(projection='3d') ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=out.ravel(), cmap=my_cmap) ax.axis('off') ax.set_title('3D conv output');
in shapes: (1, 30, 30, 30, 1) (3, 3, 3, 1, 1) ConvDimensionNumbers(lhs_spec=(0, 4, 1, 2, 3), rhs_spec=(4, 3, 0, 1, 2), out_spec=(0, 4, 1, 2, 3)) out shape: (1, 30, 30, 30, 1)
开发者文档
JAX 欢迎来自社区的贡献。请查看以下各种安装指南,以作为开发人员设置,并且开发人员专注的资源,如 Jax Enhancement Proposals。
- 参与 JAX 开发
- 从源代码构建
- 内部 API
- Autodidax: 从零开始构建 JAX 核心
- JAX Enhancement Proposals (JEPs)
- 调查回归
贡献给 JAX
每个人都可以贡献到 JAX,并且我们重视每个人的贡献。有几种贡献方式,包括:
JAX 项目遵循Google 的开源社区准则。
贡献的方式
我们欢迎拉取请求,特别是对于那些标记有欢迎贡献或好的首次问题的问题。
对于其他建议,我们要求您首先在 GitHub 的问题或讨论中寻求对您计划贡献的反馈。
使用拉取请求贡献代码
我们所有的开发都是使用 git 进行的,所以假定您具备基本知识。
按照以下步骤贡献代码:
- 签署Google 贡献者许可协议 (CLA)。有关更多信息,请参阅下面的拉取请求检查清单。
- 在存储库页面上点击Fork按钮来分叉 JAX 存储库。这将在您自己的账户中创建 JAX 存储库的副本。
- 在本地安装 Python >= 3.9 以便运行测试。
- 使用
pip
从源码安装您的分支。这允许您修改代码并立即测试:
git clone https://github.com/YOUR_USERNAME/jax cd jax pip install -r build/test-requirements.txt # Installs all testing requirements. pip install -e ".[cpu]" # Installs JAX from the current directory in editable mode.
- 将 JAX 存储库添加为上游远程,以便您可以使用它来同步您的更改。
git remote add upstream https://www.github.com/google/jax
- 创建一个分支,在该分支上进行开发:
git checkout -b name-of-change
- 并使用您喜欢的编辑器实现您的更改(我们推荐Visual Studio Code)。
- 通过从存储库顶部运行以下命令来确保您的代码通过 JAX 的 lint 和类型检查:
pip install pre-commit pre-commit run --all
- 有关更多详细信息,请参阅代码规范和类型检查。
- 确保通过从存储库顶部运行以下命令来通过测试:
pytest -n auto tests/
- JAX 的测试套件非常庞大,因此如果您知道涵盖您更改的特定测试文件,您可以限制测试为该文件;例如:
pytest -n auto tests/lax_scipy_test.py
- 您可以使用
pytest -k
标志进一步缩小测试范围以匹配特定的测试名称:
pytest -n auto tests/lax_scipy_test.py -k testLogSumExp
- JAX 还提供了对运行哪些特定测试有更精细控制的方式;有关更多信息,请参阅运行测试。
- 一旦您对自己的更改感到满意,请按如下方式创建提交(如何编写提交消息):
git add file1.py file2.py ... git commit -m "Your commit message"
- 然后将您的代码与主存储库同步:
git fetch upstream git rebase upstream/main
- 最后,将您的提交推送到开发分支,并在您的分支中创建一个远程分支,以便从中创建拉取请求:
git push --set-upstream origin name-of-change
- 请确保您的贡献是一个单一提交(参见单一更改提交和拉取请求)
- 从 JAX 仓库创建一个拉取请求并发送进行审查。在准备您的 PR 时,请检查 JAX 拉取请求检查列表,并在需要更多关于使用拉取请求的信息时参考 GitHub 帮助。
JAX 拉取请求检查列表
当您准备一个 JAX 拉取请求时,请牢记以下几点:
Google 贡献者许可协议
参与此项目必须附有 Google 贡献者许可协议(CLA)。您(或您的雇主)保留对您贡献的版权;这只是让我们可以在项目的一部分中使用和重新分发您的贡献的许可。请访问 cla.developers.google.com/
查看您当前已有的协议或签署新协议。
通常您只需要提交一次 CLA,所以如果您已经提交过一个(即使是为不同的项目),您可能不需要再次提交。如果您不确定是否已签署了 CLA,您可以打开您的 PR,我们友好的 CI 机器人将为您检查。
单一更改提交和拉取请求
一个 git 提交应该是一个独立的、单一的更改,并带有描述性的消息。这有助于审查和在后期发现问题时识别或还原更改。
拉取请求通常由单一 git 提交组成。(在某些情况下,例如进行大型重构或内部重写时,可能会包含多个提交。)在准备进行审查的拉取请求时,如果可能的话,请提前将多个提交合并。可能会使用 git rebase -i
命令来实现这一点。### 代码风格检查和类型检查
JAX 使用 mypy 和 ruff 来静态测试代码质量;在本地运行这些检查的最简单方法是通过 pre-commit 框架:
pip install pre-commit pre-commit run --all
如果您的拉取请求涉及文档笔记本,请注意还将对其运行一些检查(有关更多详细信息,请参阅更新笔记本)。
完整的 GitHub 测试套件
您的 PR 将自动通过 GitHub CI 运行完整的测试套件,该套件涵盖了多个 Python 版本、依赖版本和配置选项。这些测试通常会发现您在本地没有捕捉到的失败;为了修复问题,您可以将新的提交推送到您的分支。
受限测试套件
一旦您的 PR 被审查通过,JAX 的维护者将其标记为 Pull Ready
。这将触发一系列更广泛的测试,包括在标准 GitHub CI 中不可用的 GPU 和 TPU 后端的测试。这些测试的详细结果不对公众可见,但负责审查您的 PR 的 JAX 维护者将与您沟通任何可能揭示的失败;例如,TPU 上的数值测试通常需要与 CPU 不同的容差。
从源代码构建
首先,获取 JAX 源代码:
git clone https://github.com/google/jax cd jax
构建 JAX 涉及两个步骤:
- 构建或安装用于
jax
的 C++支持库jaxlib
。 - 安装
jax
Python 包。
构建或安装jaxlib
使用 pip 安装jaxlib
如果您只修改了 JAX 的 Python 部分,我们建议使用 pip 从预构建的 wheel 安装jaxlib
:
pip install jaxlib
请参阅JAX 自述文件获取有关 pip 安装的完整指南(例如,用于 GPU 和 TPU 支持)。
从源代码构建jaxlib
要从源代码构建jaxlib
,还必须安装一些先决条件:
- C++编译器(g++、clang 或 MSVC)
在 Ubuntu 或 Debian 上,可以使用以下命令安装所需的先决条件:
sudo apt install g++ python python3-dev
- 如果你在 Mac 上进行构建,请确保安装了 XCode 和 XCode 命令行工具。
请参阅下面的 Windows 构建说明。 - 无需在本地安装 Python 依赖项,因为在构建过程中将忽略你的系统 Python;请查看有关管理封闭 Python 的详细信息。
要为 CPU 或 TPU 构建jaxlib
,可以运行:
python build/build.py pip install dist/*.whl # installs jaxlib (includes XLA)
要为与当前系统安装的 Python 版本不同的版本构建 wheel,请将--python_version
标志传递给构建命令:
python build/build.py --python_version=3.12
本文的其余部分假定你正在为与当前系统安装匹配的 Python 版本构建。如果需要为不同版本构建,只需每次调用python build/build.py
时附加--python_version=
标志。请注意,无论是否传递--python_version
参数,Bazel 构建始终将使用封闭的 Python 安装。
有两种方法可以使用 CUDA 支持构建jaxlib
:(1) 使用python build/build.py --enable_cuda
生成带有 cuda 支持的 jaxlib wheel,或者 (2) 使用python build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12
生成三个 wheel(不带 cuda 的 jaxlib,jax-cuda-plugin 和 jax-cuda-pjrt)。你可以将gpu_plugin_cuda_version
设置为 11 或 12。
查看python build/build.py --help
以获取配置选项,包括指定 CUDA 和 CUDNN 路径的方法,这些必须已安装。这里的python
应该是你的 Python 3 解释器的名称;在某些系统上,你可能需要使用python3
。尽管使用python
调用脚本,但 Bazel 将始终使用其自己的封闭 Python 解释器和依赖项,只有build/build.py
脚本本身将由你的系统 Python 解释器处理。默认情况下,wheel 将写入当前目录的dist/
子目录。
使用修改后的 XLA 存储库从源代码构建 jaxlib。
JAX 依赖于 XLA,其源代码位于XLA GitHub 存储库中。默认情况下,JAX 使用 XLA 存储库的固定副本,但在开发 JAX 时,我们经常希望使用本地修改的 XLA 副本。有两种方法可以做到这一点:
- 使用 Bazel 的
override_repository
功能,您可以将其作为命令行标志传递给build.py
,如下所示:
python build/build.py --bazel_options=--override_repository=xla=/path/to/xla
- 修改 JAX 源代码根目录中的
WORKSPACE
文件,以指向不同的 XLA 树。
要向 XLA 贡献更改,请向 XLA 代码库发送 PR。
JAX 固定的 XLA 版本定期更新,但在每次 jaxlib
发布之前会进行特定更新。
在 Windows 上从源代码构建 jaxlib
的附加说明
在 Windows 上,按照 安装 Visual Studio 的指南来设置 C++ 工具链。需要使用 Visual Studio 2019 版本 16.5 或更新版本。如果需要启用 CUDA 进行构建,请按照 CUDA 安装指南 设置 CUDA 环境。
JAX 构建使用符号链接,需要您激活 开发者模式。
您可以使用其 Windows 安装程序 安装 Python,或者如果您更喜欢,可以使用 Anaconda 或 Miniconda 设置 Python 环境。
Bazel 的某些目标使用 bash 实用程序进行脚本编写,因此需要 MSYS2。有关详细信息,请参阅 在 Windows 上安装 Bazel。安装以下软件包:
pacman -S patch coreutils
安装 coreutils 后,realpath 命令应存在于您的 shell 路径中。
安装完成后。打开 PowerShell,并确保 MSYS2 在当前会话的路径中。确保 bazel
、patch
和 realpath
可访问。激活 conda 环境。以下命令启用 CUDA 并进行构建,请根据您的需求进行调整:
python .\build\build.py ` --enable_cuda ` --cuda_path='C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1' ` --cudnn_path='C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1' ` --cuda_version='10.1' ` --cudnn_version='7.6.5'
要添加调试信息进行构建,请加上标志 --bazel_options='--copt=/Z7'
。
为 AMD GPU 构建 ROCM jaxlib
的附加说明
您需要安装多个 ROCM/HIP 库以在 ROCM 上进行构建。例如,在具有 AMD 的 apt
存储库 的 Ubuntu 机器上,需要安装多个软件包:
sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \ rccl-dev rccl hip-dev rocfft-dev roctracer-dev hipblas-dev rocm-device-libs
要使用 ROCM 支持构建 jaxlib,可以运行以下构建命令,并根据您的路径和 ROCM 版本进行适当调整。
python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.7.0
AMD 的 XLA 代码库分支可能包含在上游 XLA 代码库中不存在的修复程序。如果遇到上游代码库的问题,可以尝试使用 AMD 的分支,方法是克隆他们的代码库:
git clone https://github.com/ROCmSoftwarePlatform/xla.git
并使用以下命令覆盖构建 JAX 所用的 XLA 代码库:
python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.7.0 \ --bazel_options=--override_repository=xla=/path/to/xla-rocm
管理封闭 Python
为了确保 JAX 的构建可复制,并在支持的平台(Linux、Windows、MacOS)上表现一致,并且正确隔离于本地系统的特定细节,我们依赖于隔离的 Python(参见rules_python)来执行通过 Bazel 执行的所有构建和测试命令。这意味着在构建期间将忽略系统 Python 安装,并且 Python 解释器以及所有 Python 依赖项将由 bazel 直接管理。
指定 Python 版本
运行build/build.py
工具时,将自动设置隔离的 Python 版本,以匹配您用于运行build/build.py
脚本的 Python 版本。若要显式选择特定版本,可以向该工具传递--python_version
参数:
python build/build.py --python_version=3.12
在幕后,隔离的 Python 版本由HERMETIC_PYTHON_VERSION
环境变量控制,在运行build/build.py
时将自动设置。如果直接运行 bazel,则可能需要以以下某种方式显式设置该变量:
# Either add an entry to your `.bazelrc` file build --repo_env=HERMETIC_PYTHON_VERSION=3.12 # OR pass it directly to your specific build command bazel build <target> --repo_env=HERMETIC_PYTHON_VERSION=3.12 # OR set the environment variable globally in your shell: export HERMETIC_PYTHON_VERSION=3.12
您可以通过在运行之间简单切换--python_version
的值来在同一台机器上连续运行不同版本的 Python 进行构建和测试。构建缓存中的所有与 Python 无关的部分将保留并在后续构建中重用。
指定 Python 依赖项
在 bazel 构建期间,所有 JAX 的 Python 依赖项都被固定到它们的特定版本。这是确保构建可复制性所必需的。JAX 依赖项的完整传递闭包以及其相应的哈希在build/requirements_lock_.txt
文件中指定(例如,Python 3.12
的build/requirements_lock_3_12.txt
)。
要更新锁定文件,请确保build/requirements.in
包含所需的直接依赖项列表,然后执行以下命令(此命令将在幕后调用pip-compile):
python build/build.py --requirements_update --python_version=3.12
或者,如果需要更多控制,可以直接运行 bazel 命令(这两个命令是等效的):
bazel run //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION=3.12
其中3.12
是您希望更新的 Python 版本。
注意,由于仍然使用的是幕后的pip
和pip-compile
工具,因此大多数由这些工具支持的命令行参数和功能也将被 Bazel 要求更新命令所承认。例如,如果希望更新程序考虑预发布版本,只需将--pre
参数传递给 bazel 命令:
bazel run //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION=3.12 -- --pre
指定本地构建的依赖项
如果需要依赖于本地的.whl
文件,例如您新构建的 jaxlib wheel,可以在build/requirements.in
中添加轮的路径,并重新运行所选 Python 版本的要求更新器命令。例如:
echo -e "\n$(realpath jaxlib-0.4.27.dev20240416-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in python build/build.py --requirements_update --python_version=3.12
指定夜间构建的依赖项
为了构建和测试最新的、潜在不稳定的 Python 依赖关系集合,我们提供了一个特殊版本的依赖关系更新命令,如下所示:
python build/build.py --requirements_nightly_update --python_version=3.12
或者,如果你直接运行bazel
(这两个命令是等效的):
bazel run //build:requirements_nightly.update --repo_env=HERMETIC_PYTHON_VERSION=3.12
与常规更新程序的区别在于,默认情况下它会接受预发布、开发和夜间包,还将搜索 https://pypi.anaconda.org/scientific-python-nightly-wheels/simple 作为额外的索引 URL,并且不会在生成的要求锁文件中放置哈希值。
使用预发布版本的 Python 进行构建
我们支持所有当前版本的 Python,但如果你需要针对不同版本(例如尚未正式发布的最新不稳定版本)进行构建和测试,请按照以下说明操作。
- 确保你已安装构建 Python 解释器本身所需的必要 Linux 软件包,以及从源代码安装关键软件包(如
numpy
或scipy
)。在典型的 Debian 系统上,你可能需要安装以下软件包:
sudo apt-get update sudo apt-get build-dep python3 -y sudo apt-get install pkg-config zlib1g-dev libssl-dev -y # to build scipy sudo apt-get install libopenblas-dev -y
- 检查你的
WORKSPACE
文件,并确保其中有指向你想要构建的 Python 版本的custom_python_interpreter()
条目。 - 运行
bazel build @python_dev//:python_dev
来构建 Python 解释器。默认情况下,它将使用 GCC 编译器进行构建。如果你希望使用 clang 进行构建,则需要设置相应的环境变量(例如--repo_env=CC=/usr/lib/llvm-17/bin/clang --repo_env=CXX=/usr/lib/llvm-17/bin/clang++
)。 - 检查上一个命令的输出。在其末尾,你会找到一个
python_register_toolchains()
入口的代码片段,其中包含你新构建的 Python。将该代码片段复制到你的WORKSPACE
文件中,可以选择是在python_init_toolchains()
入口后面(添加新版本的 Python),还是替换它(替换类似于 3.12 的现有版本,例如替换为 3.12 的自定义构建变体)。代码片段是根据你的实际设置生成的,因此应该可以直接使用,但如果需要,你可以自定义它(例如更改 Python.tgz
文件的位置,以便可以远程下载而不是本地机器上)。 - 确保在你的
WORKSPACE
文件中的python_init_repositories()
的requirements
参数中有关于你的 Python 版本的条目。例如,对于Python 3.13
,它应该有类似于"3.13": "//build:requirements_lock_3_13.txt"
的内容。 - 对于不稳定版本的 Python,可选择(但强烈建议)运行
bazel build //build:all_py_deps --repo_env=HERMETIC_PYTHON_VERSION="3.13"
,其中3.13
是您在第三步构建的 Python 解释器版本。这将使pip
从源代码拉取并构建 JAX 所有依赖的 Python 包(例如numpy
、scipy
、matplotlib
、zstandard
)。建议首先执行此步骤(即独立于实际 JAX 构建之外),以避免在构建 JAX 本身和其 Python 依赖项时发生冲突。例如,我们通常使用 clang 构建 JAX,但使用 clang 从源代码构建matplotlib
由于 GCC 和 clang 在链接时优化行为(通过-flto
标志触发的链接时优化)的差异而直接失败,默认情况下 matplotlib 默认假定 GCC。如果您针对稳定版本的 Python 进行构建,或者一般情况下不期望任何 Python 依赖项从源代码构建(即相应 Python 版本的二进制分发包已经存在于仓库中),则不需要执行此步骤。 - 恭喜,你已经为 JAX 项目构建和配置了自定义 Python!现在你可以像往常一样执行构建/测试命令,只需确保
HERMETIC_PYTHON_VERSION
环境变量已设置并指向你的新版本。 - 注意,如果你正在构建 Python 的预发布版本,则更新
requirements_lock_.txt
文件以与新构建的 Python 匹配可能会失败,因为软件包仓库没有相应的二进制包。当没有二进制包可用时,pip-compile
将继续从源代码构建,这可能会失败,因为其比在pip
安装期间执行同样操作更为严格。建议为不稳定版本的 Python 更新要求锁定文件的方法是更新最新稳定版本(例如3.12
)的要求(因此特殊的//build:requirements_dev.update
目标),然后将结果复制到不稳定 Python 的锁定文件(例如3.13
)中:
bazel run //build:requirements_dev.update --repo_env=HERMETIC_PYTHON_VERSION="3.12" cp build/requirements_lock_3_12.txt build/requirements_lock_3_13.txt bazel build //build:all_py_deps --repo_env=HERMETIC_PYTHON_VERSION="3.13" # You may need to edit manually the resultant lock file, depending on how ready # your dependencies are for the new version of Python.
安装 jax
安装完成 jaxlib
后,可以通过运行以下命令安装 jax
:
pip install -e . # installs jax
要从 GitHub 升级到最新版本,只需从 JAX 仓库根目录运行 git pull
,然后通过运行 build.py
或必要时升级 jaxlib
进行重新构建。你不应该需要重新安装 jax
,因为 pip install -e
会设置从 site-packages 到仓库的符号链接。
运行测试
有两种支持的机制可以运行 JAX 测试,即使用 Bazel 或使用 pytest。
使用 Bazel
首先,通过运行以下命令配置 JAX 构建:
python build/build.py --configure_only
你可以向 build.py
传递额外选项以配置构建;请查看 jaxlib
构建文档获取详细信息。
默认情况下,Bazel 构建使用从源代码构建的 jaxlib
运行 JAX 测试。要运行 JAX 测试,请运行:
bazel test //tests:cpu_tests //tests:backend_independent_tests
如果您有必要的硬件,还可以使用//tests:gpu_tests
和//tests:tpu_tests
。
要使用预安装的jaxlib
而不是构建它,您首先需要在 hermetic Python 中使其可用。要在 hermetic Python 中安装特定版本的jaxlib
,请运行以下命令(以jaxlib >= 0.4.26
为例):
echo -e "\njaxlib >= 0.4.26" >> build/requirements.in python build/build.py --requirements_update
或者,要从本地 wheel 安装jaxlib
(假设 Python 3.12):
echo -e "\n$(realpath jaxlib-0.4.26-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in python build/build.py --requirements_update --python_version=3.12
一旦在 hermetic 中安装了jaxlib
,请运行:
bazel test --//jax:build_jaxlib=false //tests:cpu_tests //tests:backend_independent_tests
可以使用环境变量来控制多个测试行为(参见下文)。环境变量可以通过--test_env=FLAG=value
标志传递给 Bazel 的 JAX 测试。
JAX 的一些测试适用于多个加速器(例如 GPU、TPU)。当 JAX 已安装时,您可以像这样运行 GPU 测试:
bazel test //tests:gpu_tests --local_test_jobs=4 --test_tag_filters=multiaccelerator --//jax:build_jaxlib=false --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform
您可以通过在多个加速器上并行运行单个加速器测试来加速测试。这也会触发每个加速器的多个并发测试。对于 GPU,您可以像这样操作:
NB_GPUS=2 JOBS_PER_ACC=4 J=$((NB_GPUS * JOBS_PER_ACC)) MULTI_GPU="--run_under $PWD/build/parallel_accelerator_execute.sh --test_env=JAX_ACCELERATOR_COUNT=${NB_GPUS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_ACC} --local_test_jobs=$J" bazel test //tests:gpu_tests //tests:backend_independent_tests --test_env=XLA_PYTHON_CLIENT_PREALLOCATE=false --test_tag_filters=-multiaccelerator $MULTI_GPU
使用pytest
首先,通过运行pip install -r build/test-requirements.txt
安装依赖项。
使用pytest
运行所有 JAX 测试时,建议使用pytest-xdist
,它可以并行运行测试。它作为pip install -r build/test-requirements.txt
命令的一部分安装。
从存储库根目录运行:
pytest -n auto tests
控制测试行为
JAX 以组合方式生成测试用例,您可以使用JAX_NUM_GENERATED_CASES
环境变量控制为每个测试生成和检查的案例数(默认为 10)。自动化测试当前默认使用 25 个。
例如,可以这样编写
# Bazel bazel test //tests/... --test_env=JAX_NUM_GENERATED_CASES=25`
或者
# pytest JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
自动化测试还使用默认的 64 位浮点数和整数运行测试(JAX_ENABLE_X64
):
JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25 pytest -n auto tests
您可以使用pytest的内置选择机制运行更具体的测试集,或者直接运行特定的测试文件以查看有关正在运行的案例的更详细信息:
JAX_NUM_GENERATED_CASES=5 python tests/lax_numpy_test.py
您可以通过传递环境变量JAX_SKIP_SLOW_TESTS=1
来跳过一些已知的运行缓慢的测试。
要指定从测试文件运行的特定一组测试,您可以通过--test_targets
标志传递字符串或正则表达式。例如,您可以使用以下命令运行jax.numpy.pad
的所有测试:
python tests/lax_numpy_test.py --test_targets="testPad"
Colab 笔记本在文档构建过程中会进行错误测试。
Doctests
JAX 使用 pytest 以 doctest 模式测试文档中的代码示例。您可以使用以下命令运行:
pytest docs
另外,JAX 以doctest-modules
模式运行 pytest,以确保函数文档字符串中的代码示例能够正确运行。例如,您可以在本地运行如下命令:
pytest --doctest-modules jax/_src/numpy/lax_numpy.py
请注意,当在完整包上运行 doctest 命令时,有几个文件被标记为跳过;您可以在ci-build.yaml
中查看详细信息。
类型检查
我们使用 mypy
来检查类型提示。要像 CI 一样在本地检查类型:
pip install mypy mypy --config=pyproject.toml --show-error-codes jax
或者,您可以使用 pre-commit 框架在 git 存储库中的所有暂存文件上运行此命令,自动使用与 GitHub CI 中相同的 mypy 版本:
pre-commit run mypy
代码检查
JAX 使用 ruff linter 来确保代码质量。您可以通过运行以下命令检查本地更改:
pip install ruff ruff jax
或者,您可以使用 pre-commit 框架在 git 存储库中的所有暂存文件上运行此命令,自动使用与 GitHub 测试中相同的 ruff 版本:
pre-commit run ruff
更新文档
要重新构建文档,请安装几个包:
pip install -r docs/requirements.txt
然后运行:
sphinx-build -b html docs docs/build/html -j auto
这可能需要很长时间,因为它执行文档源中的许多笔记本;如果您希望在不执行笔记本的情况下构建文档,可以运行:
sphinx-build -b html -D nb_execution_mode=off docs docs/build/html -j auto
然后您可以在 docs/build/html/index.html
中看到生成的文档。
-j auto
选项控制构建的并行性。您可以使用数字替换 auto
,以控制使用多少 CPU 核心。
更新笔记本
我们使用 jupytext 来维护 docs/notebooks
中笔记本的两个同步副本:一个是 ipynb
格式,另一个是 md
格式。前者的优点是可以直接在 Colab 中打开和执行;后者的优点是在版本控制中更容易跟踪差异。
编辑 ipynb
对于对代码和输出进行重大修改的大型更改,最简单的方法是在 Jupyter 或 Colab 中编辑笔记本。要在 Colab 界面中编辑笔记本,请打开 colab.research.google.com
,从本地仓库上传
。根据需要更新,Run all cells
然后 Download ipynb
。您可能希望使用 sphinx-build
测试它是否正确执行,如上所述。
编辑 md
对于对笔记本文本内容进行较小更改的情况,最简单的方法是使用文本编辑器编辑 .md
版本。
同步笔记本
在编辑 ipynb 或 md 版本的笔记本后,您可以通过运行 jupytext --sync
来同步这两个版本的内容;例如:
pip install jupytext==1.16.0 jupytext --sync docs/notebooks/thinking_in_jax.ipynb
jupytext
版本应与 .pre-commit-config.yaml 中指定的版本匹配。
要检查 markdown 和 ipynb 文件是否正确同步,可以使用 pre-commit 框架执行与 github CI 相同的检查:
git add docs -u # pre-commit runs on files in git staging. pre-commit run jupytext
创建新的笔记本
如果您要向文档添加新的笔记本,并希望使用此处讨论的 jupytext --sync
命令,可以通过以下命令设置您的笔记本以使用 jupytext:
jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb
这是通过在笔记本文件中添加一个 "jupytext"
元数据字段来实现的,该字段指定了所需的格式,并在调用 jupytext --sync
命令时被识别。
Sphinx 构建内的笔记本
一些笔记本是作为预提交检查的一部分和作为 Read the docs 构建的一部分自动生成的。如果单元格引发错误,则构建将失败。如果错误是有意的,您可以捕获它们,或者将单元格标记为 raises-exceptions
元数据(示例 PR)。您必须在 .ipynb
文件中手动添加此元数据。当其他人重新保存笔记本时,它将被保留。
我们排除一些笔记本的构建,例如,因为它们包含长时间的计算。请参阅 conf.py 中的 exclude_patterns
。
在 readthedocs.io
上构建文档
JAX 的自动生成文档位于 jax.readthedocs.io/
。
整个项目的文档构建受 readthedocs JAX settings 的控制。当前的设置在代码推送到 GitHub 的 main
分支后会触发文档构建。对于每个代码版本,构建过程由 .readthedocs.yml
和 docs/conf.py
配置文件驱动。
对于每个自动化文档构建,您可以查看 documentation build logs。
如果您想在 Readthedocs 上测试文档生成,请将代码推送到 test-docs
分支。该分支也将自动构建,并且您可以在这里查看生成的文档 here。如果文档构建失败,您可能希望 清除 test-docs 的构建环境。
在本地测试中,我能够在一个全新的目录中通过重放我在 Readthedocs 日志中看到的命令来完成:
mkvirtualenv jax-docs # A new virtualenv mkdir jax-docs # A new directory cd jax-docs git clone --no-single-branch --depth 50 https://github.com/google/jax cd jax git checkout --force origin/test-docs git clean -d -f -f workon jax-docs python -m pip install --upgrade --no-cache-dir pip python -m pip install --upgrade --no-cache-dir -I Pygments==2.3.1 setuptools==41.0.1 docutils==0.14 mock==1.0.1 pillow==5.4.1 alabaster>=0.7,<0.8,!=0.7.5 commonmark==0.8.1 recommonmark==0.5.0 'sphinx<2' 'sphinx-rtd-theme<0.5' 'readthedocs-sphinx-ext<1.1' python -m pip install --exists-action=w --no-cache-dir -r docs/requirements.txt cd docs python `which sphinx-build` -T -E -b html -d _build/doctrees-readthedocs -D language=en . _build/html
Internal APIs
core
Jaxpr (constvars, invars, outvars, eqns[, …]) |
|
ClosedJaxpr (jaxpr, consts) |
JAX 中文文档(十)(2)https://developer.aliyun.com/article/1559708