JAX 中文文档(十二)(4)https://developer.aliyun.com/article/1559717
JEP 18137:JAX NumPy 和 SciPy 包的范围
原文:
jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html
Jake VanderPlas
2023 年 10 月
到目前为止,jax.numpy
和 jax.scipy
的预期范围相对模糊。本文提出了这些包的明确定义范围,以更好地指导和评估未来的贡献,并促使移除一些超出范围的代码。
背景
从一开始,JAX 的目标是为在 XLA 中执行代码提供类似于 NumPy 的 API,项目的发展的一大部分是建立 jax.numpy
和 jax.scipy
命名空间,作为基于 JAX 的 NumPy 和 SciPy API 实现。一直有一个隐含的认识,即numpy
和 scipy
的某些部分超出了 JAX 的范围,但这一范围并没有明确定义。这可能会导致贡献者困惑和沮丧,因为对于潜在的 jax.numpy
和 jax.scipy
贡献是否会被接受,没有明确的答案。
为什么限制范围?
为了避免遗漏,我们应该明确一点:像 JAX 这样的项目中包含的任何代码都会为开发者带来一定的持续维护负担,虽然小但非零。项目长期成功直接与维护者能否继续为项目的所有部分承担维护工作有关:包括文档功能的记录、回答问题、修复错误等。对于任何软件工具的长期成功和可持续性,维护者必须仔细权衡每一项贡献是否对项目的目标和资源是净正面影响。
评估标准
本文提出了一个六轴评估标准,用来评判任何特定numpy
或scipy
API 的适用范围,以确定是否适合纳入 JAX。在所有轴上表现强劲的 API 是纳入 JAX 包的极佳候选;在六个轴中的任何一个上表现极差都是不适合纳入 JAX 的充分理由。
轴 1:XLA 对齐
我们考虑的第一个方向是建议 API 与本地 XLA 操作的对齐程度。例如,jax.numpy.exp()
函数几乎直接镜像了 jax.lax.exp
。numpy
、scipy.special
、numpy.linalg
、scipy.linalg
等中的大多数函数符合此标准:这类函数在考虑其是否应包含在 JAX 中时通过了 XLA 对齐检查。
另一方面,有些函数如numpy.unique()
,它们不直接对应任何 XLA 操作,在某些情况下甚至与 JAX 的当前计算模型根本不兼容,后者要求静态形状的数组(例如 unique
返回依赖于值的动态数组形状)。这类函数在考虑其是否应包含在 JAX 中时未能通过 XLA 对齐检查。
我们还考虑纯函数语义的必要性。例如,numpy.random
基于一个隐式更新的基于状态的随机数生成器,这与基于 XLA 的 JAX 计算模型根本不兼容。
轴 2:数组 API 对齐
我们考虑的第二个方向集中在Python 数组 API 标准上:在某些意义上,这是一个社区驱动的大纲,用于定义在各种用户社区中重要的面向数组编程的数组操作。如果numpy
或 scipy
中的 API 列在数组 API 标准中,这表明 JAX 应该包含它。以上述示例为例,数组 API 标准包含了 numpy.unique()
的多个变体(unique_all
、unique_counts
、unique_inverse
、unique_values
),这表明,尽管该函数与 XLA 的精确对齐并不完全,但它对于 Python 用户社区非常重要,因此 JAX 或许应该实现它。
轴 3:下游实现的存在
对于不符合 Axis 1 或 2 的功能,是否存在良好支持的下游包供应该功能是纳入 JAX 的一个重要考虑因素。一个很好的例子是 scipy.optimize
:虽然 JAX 包含了对 scipy.optimize
功能的最小包装集,但更完整的实现存在于由 JAX 协作者积极维护的 JAXopt 包中。在这种情况下,我们应倾向于指向用户和贡献者这些专业化的包,而不是在 JAX 自身重新实现这些 API。
Axis 4: Implementation 的复杂性与健壮性
对于不符合 XLA 的功能,一个考虑因素是提议实现的复杂程度。这在某种程度上与 Axis 1 一致,但仍然是需要强调的。有许多函数已经贡献给 JAX,它们具有相对复杂的实现,难以验证并引入了过多的维护负担;一个例子是 jax.scipy.special.bessel_jn()
:在撰写本 JEP 时,其当前实现是一个非直观的迭代逼近,存在 某些领域的收敛问题,而 提出的修复方案 则引入了更多的复杂性。如果在接受贡献时更加仔细地权衡了实现的复杂性和健壮性,我们可能选择不接受这个包的贡献。
Axis 5: 功能型 vs. 面向对象的 API
JAX 最适合使用功能型 API 而不是面向对象的 API。面向对象的 API 通常会隐藏不纯的语义,使其往往难以实现良好。NumPy 和 SciPy 通常坚持使用功能型 API,但有时提供面向对象的便利包装器。
例如 numpy.polynomial.Polynomial
,它包装了像 numpy.polyadd()
,numpy.polydiv()
等低级操作。一般情况下,当既有功能型 API 又有面向对象 API 时,JAX 应避免为面向对象 API 提供包装器,而应为功能型 API 提供包装器。
在只存在面向对象的 API 的情况下,JAX 应避免提供包装器,除非在其他轴上有很强的案例支持。
Axis 6: 对 JAX 用户和利益相关者的“重要性”
决定在 JAX 中包含 NumPy/SciPy API 还应考虑到该算法对一般用户社区的重要性。诚然,很难量化谁是“利益相关者”以及如何衡量这种重要性;但我们包括这一点是为了明确说明,在 JAX 的 NumPy 和 SciPy 包装中包含什么的任何决定都将涉及某种不容易量化的自由裁量权。
对于现有 API,通过在 github 中搜索使用情况可能有助于确定其重要性或缺失;例如,我们可以回到上面讨论过的 jax.scipy.special.bessel_jn()
:搜索显示,这个函数在 github 上仅有少数用例,这可能部分原因与先前提到的精度问题有关。
评估:什么在范围内?
在本节中,我们将尝试根据上述标准评估 NumPy 和 SciPy 的 API,包括当前 JAX API 中的一些示例。这不会是所有现有函数和类的详尽列表,而是一个更一般的子模块和主题讨论,附带相关示例。
NumPy API
✅ numpy
命名空间
我们认为主要 numpy
命名空间中的函数基本上都适用于 JAX,因为它与 XLA(轴 1)和 Python 数组 API(轴 2)的一般对齐性以及对 JAX 用户社区的一般重要性(轴 6)保持一致。一些函数可能处于边界地带(例如 numpy.intersect1d()
,np.setdiff1d()
,np.union1d()
可能在某些标准下不完全符合),但为简单起见,我们声明所有主要 numpy 命名空间中的数组函数都适用于 JAX。
✅ numpy.linalg
和 numpy.fft
numpy.linalg
和 numpy.fft
子模块包含许多与 XLA 提供的功能广泛对齐的函数。其他函数具有复杂的特定设备的低级实现,但代表一种情况,其中对利益相关者的重要性(轴 6)超过了复杂性。因此,我们认为这两个子模块都适用于 JAX。
❌ numpy.random
numpy.random
对于 JAX 而言超出范围,因为基于状态的随机数生成器与 JAX 的计算模型基本不兼容。相反,我们将重点放在 jax.random
上,它使用基于计数器的伪随机数生成器提供类似的功能。
❌ numpy.ma
和 numpy.polynomial
numpy.ma
和 numpy.polynomial
子模块主要关注通过其他函数手段表达的计算的面向对象接口(轴 5)。因此,我们认为它们不适用于 JAX。
❌ numpy.testing
NumPy 的测试功能只对主机端计算有意义,因此我们在 JAX 中不包含任何包装器。尽管如此,JAX 数组与 numpy.testing
兼容,并且在整个 JAX 测试套件中频繁使用它。
SciPy API
SciPy 没有顶层命名空间中的函数,但包含多个子模块。我们逐一考虑每个子模块,略过已弃用的模块。
❌ scipy.cluster
scipy.cluster
模块包含用于层次聚类、K 均值和相关算法的工具。这些在多个方面表现不佳,更适合由下游包处理。JAX 中已经存在一个函数(jax.scipy.cluster.vq.vq()
),但在 github 上没有明显的使用示例,这表明聚类对于 JAX 用户并不广泛重要。
建议:弃用并移除 jax.scipy.cluster.vq()
。
❌ scipy.constants
scipy.constants
模块包含数学和物理常数。这些常数可以直接在 JAX 中使用,因此没有必要在 JAX 中重新实现。
❌ scipy.datasets
scipy.datasets
模块包含获取和加载数据集的工具。这些获取的数据集可以直接在 JAX 中使用,因此没有必要在 JAX 中重新实现。
✅ scipy.fft
scipy.fft
模块包含与 XLA 提供的功能大致对齐的函数,并且在其他方面表现良好。因此,我们认为它们适用于 JAX 的范围内。
❌ scipy.integrate
scipy.integrate
模块包含用于数值积分的函数。其中更复杂的函数(quad
、dblquad
、ode
)基于动态评估的循环算法,根据轴 1 和 4 应视为 JAX 范围之外。jax.experimental.ode.odeint()
相关,但相当有限,未处于任何活跃开发状态。
JAX 当前确实包括 jax.scipy.integrate.trapezoid()
,但这仅因为numpy.trapz()
最近已弃用,推荐使用此功能。对于任何特定输入,其实现可以用一行 jax.numpy
表达式替换,因此它并不是提供的特别有用的 API。
基于轴 1、2、4 和 6,scipy.integrate
应被视为 JAX 范围之外。
建议:移除 jax.scipy.integrate.trapezoid()
,此功能已在 JAX 0.4.14 中添加。
❌ scipy.interpolate
scipy.interpolate
模块提供了在一维或多维中进行插值的低级和面向对象的例程。从多个角度评估,这些 API 表现不佳:它们基于类而非低级,除了最简单的方法外,无法有效地用 XLA 操作表达。
JAX 当前具有 scipy.interpolate.RegularGridInterpolator
的包装器。如果今天考虑此贡献,我们可能会根据以上标准拒绝它。但此代码相当稳定,因此继续维护没有太大的风险。
未来,我们应考虑将 scipy.interpolate
的其他成员视为 JAX 范围之外。
❌ scipy.io
scipy.io
子模块涉及文件输入/输出。在 JAX 中重新实现这一功能没有必要。
✅ scipy.linalg
scipy.linalg
子模块包含与 XLA 提供的功能大致对应的函数,快速线性代数对 JAX 用户社区至关重要。因此,我们认为它适用于 JAX 的范围之内。
❌ scipy.ndimage
scipy.ndimage
子模块包含一组用于处理图像数据的工具。其中许多与 scipy.signal
中的工具重叠(例如卷积和滤波)。JAX 目前在 jax.scipy.ndimage.map_coordinates()
中提供了一个 scipy.ndimage
API。此外,JAX 在 jax.image
模块中提供了一些与图像相关的工具。DeepMind 生态系统包括 dm-pix,一个更全面的用于在 JAX 中进行图像处理的工具集。考虑到所有这些因素,我建议 scipy.ndimage
应被视为 JAX 核心之外的范畴;我们可以将感兴趣的用户和贡献者指向 dm-pix。我们可以考虑将 map_coordinates
移至 dm-pix
或其他适当的包中。
❌ scipy.odr
scipy.odr
模块提供了一个面向对象的 ODRPACK
包装器,用于执行正交距离回归。目前尚不清楚是否可以使用现有的 JAX 原语清晰地表达这一功能,因此我们认为它超出了 JAX 本身的范畴。
❌ scipy.optimize
scipy.optimize
模块提供了用于优化的高级和低级接口。这样的功能对许多 JAX 用户非常重要,在 JAX 创建 jax.scipy.optimize
包装器时非常早就开始。然而,这些程序的开发人员很快意识到 scipy.optimize
API 过于约束,并且不同的团队开始开发 JAXopt 包和 Optimistix 包,每个包都包含了在 JAX 中更全面和经过更好测试的优化程序集。
由于这些受到良好支持的外部包,我们现在认为 scipy.optimize
超出了 JAX 的范围。
建议:弃用 jax.scipy.optimize
或使其成为一个轻量级的包装器,周围包装一个可选的 JAXopt 或 Optimistix 依赖。
🟡 scipy.signal
scipy.signal
模块则有所不同:一些函数完全适用于 JAX(例如correlate
和convolve
,这些函数是lax.conv_general_dilated
的更友好的包装),而其他许多函数则完全不适用于 JAX(专门领域的工具没有合适的降低路径到 XLA)。对于jax.scipy.signal
的潜在贡献将需要具体问题具体分析。
🟡 scipy.sparse
scipy.sparse
子模块主要包含了多种格式的稀疏矩阵和数组的存储和操作数据结构。此外,scipy.sparse.linalg
还包含了许多无矩阵的求解器,适用于稀疏矩阵、稠密矩阵和线性算子。
scipy.sparse
的数组和矩阵数据结构也超出了 JAX 的范围,因为它们与 JAX 的计算模型不符(例如,许多操作依赖于动态大小的缓冲区)。JAX 已经开发了jax.experimental.sparse
模块作为一组更符合 JAX 计算约束的替代数据结构。因此,我们认为scipy.sparse
中的数据结构超出了 JAX 的范围。
另一方面,scipy.sparse.linalg
已经被证明是一个有趣的领域,jax.scipy.sparse.linalg
包括了bicgstab
、cg
和gmres
求解器。这些对于 JAX 用户社区(轴 6)非常有用,但在其他轴上并不适用。它们非常适合移入一个下游库;一个潜在的选择可能是Lineax,它包括了多个基于 JAX 构建的线性求解器。
建议:考虑将稀疏求解器移入 Lineax,并且将scipy.sparse
视为 JAX 范围外的内容。
❌ scipy.spatial
scipy.spatial
模块主要包含面向对象的空间/距离计算和最近邻搜索接口。这在很大程度上超出了 JAX 的范围。
scipy.spatial.transform
子模块提供了用于操作三维空间旋转的工具。这是一个相对复杂的面向对象接口,也许最好由下游项目更好地服务。JAX 目前在jax.scipy.spatial.transform
中部分实现了Rotation
和Slerp
;这些是对基本函数的面向对象包装器,引入了非常庞大的 API 表面,且使用者非常少。我们认为它们超出了 JAX 本身的范围,用户最好由一个假设的下游项目更好地服务。
scipy.spatial.distance
子模块包含一组有用的距离度量标准,可能会诱人地为这些提供 JAX 包装器。尽管如此,通过jit
和vmap
,用户可以很容易地根据需要从头开始定义大多数这些的高效版本,因此将它们添加到 JAX 中并不特别有益。
建议:考虑废弃和移除Rotation
和Slerp
API,并考虑将scipy.spatial
整体视为不适合未来贡献。
✅ scipy.special
scipy.special
模块包括一些更专业函数的实现。在许多情况下,这些函数完全在范围内:例如,像gammaln
、betainc
、digamma
和许多其他函数直接对应于可用的 XLA 基元,并且明显在轴 1 和其他轴上在范围内。
其他函数需要更复杂的实现;一个上面提到的例子是bessel_jn
。尽管在轴 1 和 2 上不对齐,但这些函数往往在轴 6 上非常强大:scipy.special
提供了在多个领域中进行计算所需的基本函数,因此即使是具有复杂实现的函数,只要实现良好且健壮,也应倾向于在范围内。
有一些现有的函数包装器值得我们更仔细地看一看;例如:
jax.scipy.special.lpmn()
: 这个函数通过一个复杂的fori_loop
生成 Legendre 多项式,其方式与 scipy 的 API 不匹配(例如,对于scipy
,z
必须是标量,而对于 JAX,则z
必须是 1D 数组)。该函数有少数发现的用途,使其成为 Axes 1、2、4 和 6 上的一个薄弱候选者。jax.scipy.special.lpmn_values()
: 这与上述的lmpn
有类似的弱点。jax.scipy.special.sph_harm()
:此函数基于 lpmn 构建,其 API 与对应的scipy
函数不同。jax.scipy.special.bessel_jn()
:如上述第 4 轴中讨论的那样,这在实现的健壮性方面存在弱点,使用较少。我们可能会考虑用一个新的、更健壮的实现替换它(例如 #17038)。
建议:重构并提高bessel_jn
的健壮性和测试覆盖率。如果无法修改以更接近scipy
的 API,则考虑废弃lpmn
、lpmn_values
和sph_harm
。
✅ scipy.stats
scipy.stats
模块包含广泛的统计函数,包括离散和连续分布、汇总统计以及假设检验。JAX 目前在jax.scipy.stats
中包装了其中一些,主要包括大约 20 种统计分布以及一些其他函数(如mode
、rankdata
、gaussian_kde
)。总体来说,这些与 JAX 很好地对齐:分布通常可以用高效的 XLA 操作表达,API 清晰且功能齐全。
目前我们没有任何假设检验函数的包装器,这可能是因为这些对于 JAX 的主要用户群体不太有用。
关于分布,在某些情况下,tensorflow_probability
提供类似的功能,未来我们可能会考虑是否应该废弃 scipy.stats 中的分布以支持这种实现。
建议:未来,我们应将统计分布和汇总统计视为范围内的内容,并考虑假设检验及其相关功能通常不在范围内。
调查回归
原文:
jax.readthedocs.io/en/latest/investigating_a_regression.html
所以,您更新了 JAX,并且遇到了速度回归?您有一点时间并且准备好调查吗?让我们首先提一个 JAX 问题。但如果您能够确定触发回归的提交,这将确实帮助我们。
本文说明了我们如何确定导致15% 性能回归的提交。
步骤
如果复现器足够快,这可以很容易地完成。这是一种蛮力方法而非二分法,但如果复现器足够快,它会很有效。这确保了您始终测试兼容的 XLA 和 JAX 提交。它还限制了 XLA 的重新编译。
这里是建议的调查策略:
- 您可以在两个版本之间的每日容器上进行蛮力测试。
- 每小时重新编译,同时保持 XLA 和 JAX 的同步。
- 最终验证:也许需要手动检查几个提交(或者使用 git bisect)。
每日调查。
这可以通过使用JAX-Toolbox 每夜容器来完成。
- 有些日子,错误会阻止容器的构建,或者会出现临时回归。请忽略这些日子。
- 因此,您应该最终得到出现回归的具体日期或几天。
- 要自动化这个过程,您需要两个 Python 脚本:
- test_runner.sh: 将启动容器和测试。
- test.sh: 将安装缺失的依赖项并运行测试。
这里是用于该问题的真实示例脚本:github.com/google/jax/issues/17686
- test_runner.sh:
for m in 7 8 9; do for d in `seq -w 1 30`; do docker run -v $PWD:/dir --gpus=all ghcr.io/nvidia/jax:nightly-2023-0${m}-${d} /bin/bash /dir/test.sh &> OUT-0${m}-${d} done Done
- test.sh:
pip install jmp pyvista numpy matplotlib Rtree trimesh jmp termcolor orbax git clone https://github.com/Autodesk/XLB cd XLB export PYTHONPATH=. export CUDA_VISIBLE_DEVICES=0 # only 1 GPU is needed python3 examples/performance/MLUPS3d.py 256 200
然后,您可以对每个输出执行 grep 命令以查看回归发生的时间:grep MLUPS OUT*
。这是我们得到的结果:
OUT-07-06:MLUPS: 587.9240990200157 OUT-07-07:MLUPS: 587.8907972116419 OUT-07-08:MLUPS: 587.3186499464459 OUT-07-09:MLUPS: 587.3130127722537 OUT-07-10:MLUPS: 587.8526619429658 OUT-07-17:MLUPS: 570.1631097290182 OUT-07-18:MLUPS: 570.2819775617064 OUT-07-19:MLUPS: 570.1672213357352 OUT-07-20:MLUPS: 587.437153685251 OUT-07-21:MLUPS: 587.6702557143142 OUT-07-25:MLUPS: 577.3063618431178 OUT-07-26:MLUPS: 577.2362978080912 OUT-07-27:MLUPS: 577.2101850145785 OUT-07-28:MLUPS: 577.0716349809895 OUT-07-29:MLUPS: 577.4223280707176 OUT-07-30:MLUPS: 577.2255967221336 OUT-08-01:MLUPS: 577.277685388252 OUT-08-02:MLUPS: 577.0137874289354 OUT-08-03:MLUPS: 577.1333281553946 OUT-08-04:MLUPS: 577.305012020407 OUT-08-05:MLUPS: 577.2143988866626 OUT-08-06:MLUPS: 577.2409145495443 OUT-08-07:MLUPS: 577.2602819927345 OUT-08-08:MLUPS: 577.2823738293221 OUT-08-09:MLUPS: 577.3453199728248 OUT-08-11:MLUPS: 577.3161423260563 OUT-08-12:MLUPS: 577.1697775786824 OUT-08-13:MLUPS: 577.3049883393633 OUT-08-14:MLUPS: 576.9051978525331 OUT-08-15:MLUPS: 577.5331743016213 OUT-08-16:MLUPS: 577.5117505070573 OUT-08-18:MLUPS: 577.5930698237612 OUT-08-19:MLUPS: 577.3539885757353 OUT-08-20:MLUPS: 577.4190113959127 OUT-08-21:MLUPS: 577.300394253605 OUT-08-22:MLUPS: 577.4263792037783 OUT-08-23:MLUPS: 577.4087536357031 OUT-08-24:MLUPS: 577.1094728438082 OUT-08-25: File "/XLB/examples/performance/MLUPS3d.py", line 5, in <module> OUT-08-26:MLUPS: 537.0164618489928 OUT-08-27:MLUPS: 536.9545448661609 OUT-08-28:MLUPS: 536.2887650464874 OUT-08-29:MLUPS: 536.7178471720636 OUT-08-30:MLUPS: 536.6978912984252 OUT-09-01:MLUPS: 536.7030899164106 OUT-09-04:MLUPS: 536.5339818238837 OUT-09-05:MLUPS: 536.6507808565617 OUT-09-06:MLUPS: 536.7144494518315 OUT-09-08:MLUPS: 536.7376612408998 OUT-09-09:MLUPS: 536.7798324141778 OUT-09-10:MLUPS: 536.726157440174 OUT-09-11:MLUPS: 536.7446210750584 OUT-09-12:MLUPS: 536.6707332269023 OUT-09-13:MLUPS: 536.6777936517823 OUT-09-14:MLUPS: 536.7581523280307 OUT-09-15:MLUPS: 536.6156273667873 OUT-09-16:MLUPS: 536.7320935035265 OUT-09-17:MLUPS: 536.7104991444398 OUT-09-18:MLUPS: 536.7492269469092 OUT-09-19:MLUPS: 536.6760131792959 OUT-09-20:MLUPS: 536.7361260076634
这发现 8-24 是好的,但 8-26 是坏的。在 8-25 上有另一个问题,阻止了获取结果。因此,我们需要在 8-24 和 8-26 之间的每小时进行调查。较早的减速可以忽略,仅需在这些日期之间再进行一次小时调查即可。
每小时调查。
这在两个日期之间的每个小时检出 JAX 和 XLA,重建所有内容并运行测试。这些脚本结构不同。我们启动工作容器并保持它。然后在容器内,我们只触发增量 XLA 构建,第一次构建除外。因此,在第一次迭代后速度要快得多。
- test_runner2.sh:
# Execute this script inside the container: # docker run -v $PWD:/dir --gpus=all ghcr.io/nvidia/jax:nightly-2023-08-24 /bin/bash cd /opt/xla-source git remote update cd /opt/jax-source git remote update pip install jmp pyvista numpy matplotlib Rtree trimesh jmp termcolor orbax cd /tmp git clone https://github.com/Autodesk/XLB cd XLB for d in `seq -w 24 26`; do for h in `seq -w 0 24`; do echo $m $d $h /bin/bash /dir/test2.sh Aug $d 2023 $h:00:00 &> OUT-08-${d}-$h done done
- test2.sh:
echo "param: $@" cd /opt/xla-source git checkout `git rev-list -1 --before="$*" origin/main` git show -q cd /opt/jax-source git checkout `git rev-list -1 --before="$*" origin/main` git show -q rm /opt/jax-source/dist/jax*.whl build-jax.sh # The script is in the nightly container export PYTHONPATH=. export CUDA_VISIBLE_DEVICES=0 # only 1 GPU is needed python3 examples/performance/MLUPS3d.py 256 200
现在,您可以在新的输出文件上执行 grep 命令,查看问题出现的小时。
最终验证
通过这样,您需要检查这些小时之间的 JAX 和 XLA 历史记录。也许有几个提交需要测试。如果您想要花哨一点,可以使用 git bisect。
是否可以改进这个过程?
是的!如果这是一个崩溃回归,能够进行二分法测试将非常有用。但这会更加复杂。如果有人想贡献这样的说明,请提交 PR 😉
对于速度回归,二分法可以隐藏一些信息。我们不会那么容易地看到这里有两个回归。
0899164106
OUT-09-04:MLUPS: 536.5339818238837
OUT-09-05:MLUPS: 536.6507808565617
OUT-09-06:MLUPS: 536.7144494518315
OUT-09-08:MLUPS: 536.7376612408998
OUT-09-09:MLUPS: 536.7798324141778
OUT-09-10:MLUPS: 536.726157440174
OUT-09-11:MLUPS: 536.7446210750584
OUT-09-12:MLUPS: 536.6707332269023
OUT-09-13:MLUPS: 536.6777936517823
OUT-09-14:MLUPS: 536.7581523280307
OUT-09-15:MLUPS: 536.6156273667873
OUT-09-16:MLUPS: 536.7320935035265
OUT-09-17:MLUPS: 536.7104991444398
OUT-09-18:MLUPS: 536.7492269469092
OUT-09-19:MLUPS: 536.6760131792959
OUT-09-20:MLUPS: 536.7361260076634
这发现 8-24 是好的,但 8-26 是坏的。在 8-25 上有另一个问题,阻止了获取结果。因此,我们需要在 8-24 和 8-26 之间的每小时进行调查。较早的减速可以忽略,仅需在这些日期之间再进行一次小时调查即可。 ## 每小时调查。 这在两个日期之间的每个小时检出 JAX 和 XLA,重建所有内容并运行测试。这些脚本结构不同。我们启动工作容器并保持它。然后在容器内,我们只触发增量 XLA 构建,第一次构建除外。因此,在第一次迭代后速度要快得多。 + test_runner2.sh: ```py # Execute this script inside the container: # docker run -v $PWD:/dir --gpus=all ghcr.io/nvidia/jax:nightly-2023-08-24 /bin/bash cd /opt/xla-source git remote update cd /opt/jax-source git remote update pip install jmp pyvista numpy matplotlib Rtree trimesh jmp termcolor orbax cd /tmp git clone https://github.com/Autodesk/XLB cd XLB for d in `seq -w 24 26`; do for h in `seq -w 0 24`; do echo $m $d $h /bin/bash /dir/test2.sh Aug $d 2023 $h:00:00 &> OUT-08-${d}-$h done done
- test2.sh:
echo "param: $@" cd /opt/xla-source git checkout `git rev-list -1 --before="$*" origin/main` git show -q cd /opt/jax-source git checkout `git rev-list -1 --before="$*" origin/main` git show -q rm /opt/jax-source/dist/jax*.whl build-jax.sh # The script is in the nightly container export PYTHONPATH=. export CUDA_VISIBLE_DEVICES=0 # only 1 GPU is needed python3 examples/performance/MLUPS3d.py 256 200
现在,您可以在新的输出文件上执行 grep 命令,查看问题出现的小时。
最终验证
通过这样,您需要检查这些小时之间的 JAX 和 XLA 历史记录。也许有几个提交需要测试。如果您想要花哨一点,可以使用 git bisect。
是否可以改进这个过程?
是的!如果这是一个崩溃回归,能够进行二分法测试将非常有用。但这会更加复杂。如果有人想贡献这样的说明,请提交 PR 😉
对于速度回归,二分法可以隐藏一些信息。我们不会那么容易地看到这里有两个回归。