JAX 中文文档(十一)(4)

简介: JAX 中文文档(十一)

JAX 中文文档(十一)(3)https://developer.aliyun.com/article/1559782


JAX 类型提升的设计

原文:jax.readthedocs.io/en/latest/jep/9407-type-promotion.html

Jake VanderPlas, December 2021

任何数值计算库设计中面临的挑战之一是如何处理不同类型值之间的操作选择。本文概述了 JAX 使用的提升语义背后的思维过程,总结在JAX 类型提升语义中。

JAX 类型提升的目标

JAX 的数值计算 API 是模仿 NumPy 的,但增加了一些功能,包括能够针对 GPU 和 TPU 等加速器进行优化。这使得采用  NumPy 的类型提升系统对 JAX 用户不利:NumPy 的类型提升规则严重偏向于 64 位输出,这对于加速器上的计算是有问题的。像 GPU 和  TPU 这样的设备通常需要付出显著的性能代价来使用 64 位浮点类型,并且在某些情况下根本不支持本地 64 位浮点类型。

这种问题类型提升语义的简单例子可以在 32 位整数和浮点数之间的二进制操作中看到:

import numpy as np
np.dtype(np.int32(1) + np.float32(1)) 
dtype('float64') 

NumPy 倾向于生成 64 位值是使用 NumPy API 进行加速计算的一个长期问题,目前还没有一个很好的解决方案。因此,JAX 已经开始重新思考以加速器为目标的 NumPy 风格类型提升。

回顾:表格和格子

在我们深入细节之前,让我们花点时间退后一步,思考如何思考类型提升问题。考虑 Python 内置数值类型(即intfloatcomplex)之间的算术操作,我们可以用几行代码生成 Python 用于这些类型值加法的类型提升表:

import pandas as pd
types = [int, float, complex]
name = lambda t: t.__name__
pd.DataFrame([[name(type(t1(1) + t2(1))) for t1 in types] for t2 in types],
             index=[name(t) for t in types], columns=[name(t) for t in types]) 
int float complex
int int float complex
float float float complex
complex complex complex complex

这张表详细列出了 Python 的数值类型提升行为,但事实证明有一种更为简洁的补充表示:表示法,其中任意两个节点之间的上确界是它们提升到的类型。Python 提升表的格表示法要简单得多:

显示代码单元格源代码 隐藏代码单元格源代码

#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {'int': ['float'], 'float': ['complex']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {'int': [0, 0], 'float': [1, 0], 'complex': [2, 0]}
fig, ax = plt.subplots(figsize=(8, 2))
nx.draw(graph, with_labels=True, node_size=4000, node_color='lightgray', pos=pos, ax=ax, arrowsize=20) 
```</details> ![../_images/818a3cf499d15c3be1d4c116db142da0418c174873f21e1ffcde679c6058f918.png](https://gitee.com/OpenDocCN/dsai-docs-zh/raw/master/docs/jax/img/1adb771731c921aaf44122c0c8a2c96f.png)
这个格是促进表中信息的紧凑编码。您可以通过跟踪到两个节点的第一个共同子节点(包括节点本身)找到两个输入的类型提升的结果;在数学上,这个共同子节点被称为对格上的*上确界*,或*最小上界*,或*结合*的操作;这里我们将这个操作称为**结合**。
概念上,箭头表示允许在源和目标之间进行*隐式类型提升*:例如,允许从整数到浮点数的隐式提升,但不允许从浮点数到整数的隐式提升。
请记住,通常并非每个有向无环图(DAG)都满足格的性质。格要求每对节点之间存在唯一的最小上界;例如,以下两个 DAG 不是格:
<details class="hide above-input"><summary aria-label="Toggle hidden content">显示代码单元格源码 隐藏代码单元格源码</summary>
```py
#@title
import networkx as nx
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 2, figsize=(10, 2))
lattice = {'A': ['B', 'C']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {'A': [0, 0], 'B': [1, 0.5], 'C': [1, -0.5]}
nx.draw(graph, with_labels=True, node_size=2000, node_color='lightgray', pos=pos, ax=ax[0], arrowsize=20)
ax[0].set(xlim=[-0.5, 1.5], ylim=[-1, 1])
lattice = {'A': ['C', 'D'], 'B': ['C', 'D']}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {'A': [0, 0.5], 'B': [0, -0.5], 'C': [1, 0.5], 'D': [1, -0.5]}
nx.draw(graph, with_labels=True, node_size=2000, node_color='lightgray', pos=pos, ax=ax[1], arrowsize=20)
ax[1].set(xlim=[-0.5, 1.5], ylim=[-1, 1]); 
```</details> ![../_images/a0acbd07f9486d95c10a36c11301d528fb7e65d671d622226151c431b3e36c62.png](https://gitee.com/OpenDocCN/dsai-docs-zh/raw/master/docs/jax/img/5a266a4810ed90d79776de9034ad3c61.png)
左边的 DAG 不是格,因为节点`B`和`C`没有上界;右边的 DAG 有两个问题:首先,节点`C`和`D`没有上界,其次,节点`A`和`B`的最小上界无法*唯一*确定:`C`和`D`都是候选项,但它们是不可排序的。
### 类型提升格的属性
在格中指定类型提升确保了许多有用的属性。用\(\vee\)运算符表示格中的结合,我们有:
**存在性:** 格的定义要求每对元素都存在唯一的格结合:\(\forall (a, b): \exists !(a \vee b)\)
**交换律:** 格的结合运算是交换的:\(\forall (a, b): a\vee b = b \vee a\).
**结合律:** 格的结合运算是结合的:\(\forall (a, b, c): a \vee (b \vee c) = (a \vee b) \vee c\).
另一方面,这些属性意味着它们对能够表示的类型提升系统有所限制;特别是**并非每个类型提升表都可以用格表示**。NumPy 的完整类型提升表就是一个快速反例:这里有三种标量类型,它们在 NumPy 中的提升行为是非结合的。
```py
import numpy as np
a, b, c = np.int8(1), np.uint8(1), np.float16(1)
print(np.dtype((a + b) + c))
print(np.dtype(a + (b + c))) 
float32
float16 

这样的结果可能会让用户感到惊讶:我们通常期望数学表达式映射到数学概念,所以,例如,a + b + c应等同于c + b + ax * (y + z)应等同于x * y + x * z。如果类型提升不是结合的或不是交换的,这些属性将不再适用。

此外,基于格子的类型提升系统与基于表的系统相比,在概念上更简单和更易理解。例如,JAX 识别 18 种不同的类型:一个包含 18 个节点和之间稀疏、有充分动机的连接的提升格子,比 324 个条目的表在脑中更容易维持。

因此,我们选择为 JAX 使用基于格子的类型提升系统。

类别内的类型提升

数值计算库通常提供不仅仅是intfloatcomplex,在每个类别中,都有各种可能的精度,由数值表示中使用的位数表示。我们在这里考虑的类别是:

  • 无符号整数,包括uint8uint16uint32uint64(我们简称为u8u16u32u64
  • 有符号整数,包括int8int16int32int64(我们简称为i8i16i32i64
  • 浮点数,包括float16float32float64(我们简称为f16f32f64
  • 复数浮点数,包括complex64complex128(我们简称为c64c128

Numpy 在每个这四个类别内的类型提升语义相对来说是相对简单的:类型的有序层次结构直接转换为四个分离的格子,表示类内类型提升规则:

显示代码单元源代码 隐藏代码单元源代码

#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'u8': [0, 0], 'u16': [1, 0], 'u32': [2, 0], 'u64': [3, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [1, 2], 'f32': [2, 2], 'f64': [3, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 4))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax) 
```</details> ![../_images/2d8495bcb006c34b42eeb4f3e0c6530fdef0bd7364c56184993925f0cf157abc.png](https://gitee.com/OpenDocCN/dsai-docs-zh/raw/master/docs/jax/img/3704ee4a86ce603a27b8fdb41d064d81.png)
关于 JAX 避免的值提升为 64 位,这些同类别的提升语义在每种类型类别内部是没有问题的:产生 64 位输出的唯一方式是有一个 64 位输入。
## 输入 Python 标量
现在让我们考虑 Python 标量如何融入其中。
在 NumPy 中,提升行为取决于输入是数组还是标量。例如,在操作两个标量时,适用正常的提升规则:
```py
x = np.int8(0)  # int8 scalar
y = 1  # Python int = int64 scalar
(x + y).dtype 
dtype('int64') 

在这里,Python 值1被视为int64,并且简单的类内规则导致int64结果。

然而,在 Python 标量和 NumPy 数组之间的操作中,标量会延续到数组的 dtype。例如:

x = np.zeros(1, dtype='int8')  # int8 array
y = 1  # Python int = int64 scalar
(x + y).dtype 
dtype('int8') 

忽略int64标量的位宽度,而是延续数组的位宽度。

这里还有一个细节:当 NumPy 类型提升涉及标量时,输出的 dtype 取决于值:如果 Python 标量过大,超出了给定 dtype 的范围,则被提升为兼容的类型:

x = np.zeros(1, dtype='int8')  # int8 array
y = 1000  # int64 scalar
(x + y).dtype 
dtype('int16') 

出于 JAX 的目的,依赖值的提升是不可行的,因为 JIT 编译和其他转换的性质使其作用于数据的抽象表示,而不参考其值。

忽略依赖值的影响,NumPy 类型提升的有符号整数分支可以在以下格点中表示,我们将使用 * 标记标量数据类型:

显示代码单元格来源 隐藏代码单元格来源

#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i8*': ['i16*'], 'i16*': ['i32*'], 'i32*': ['i64*'], 'i64*': ['i8'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i8*': [0, 1], 'i16*': [2, 1], 'i32*': [4, 1], 'i64*': [6, 1],
  'i8': [9, 1], 'i16': [11, 1], 'i32': [13, 1], 'i64': [15, 1],
}
fig, ax = plt.subplots(figsize=(12, 4))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax)
ax.text(3, 1.6, "Scalar Types", ha='center', fontsize=14)
ax.text(12, 1.6, "Array Types", ha='center', fontsize=14)
ax.set_ylim(-1, 3); 
```</details> ![../_images/7e8c3295e403209560d8e142c5c830d79456a4e6d207dd1a7e4d15b55c56006b.png](https://gitee.com/OpenDocCN/dsai-docs-zh/raw/master/docs/jax/img/e129510e5b34d2fc6197a149b22de27c.png)
在 `uint`、`float` 和 `complex` 格点内,类似的模式也成立。
为了简单起见,让我们将每个标量类型的类别折叠为单个节点,分别表示为 `u*`、`i*`、`f*` 和 `c*`。我们的类别内格点集现在可以这样表示:
<details class="hide above-input"><summary aria-label="Toggle hidden content">显示代码单元格来源 隐藏代码单元格来源</summary>
```py
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'u*': ['u8'], 'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i*': ['i8'], 'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f*': ['f16'], 'f16': ['f32'], 'f32': ['f64'],
  'c*': ['c64'], 'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'u*': [0, 0], 'u8': [3, 0], 'u16': [5, 0], 'u32': [7, 0], 'u64': [9, 0],
  'i*': [0, 1], 'i8': [3, 1], 'i16': [5, 1], 'i32': [7, 1], 'i64': [9, 1],
  'f*': [0, 2], 'f16': [5, 2], 'f32': [7, 2], 'f64': [9, 2],
  'c*': [0, 3], 'c64': [7, 3], 'c128': [9, 3],
}
fig, ax = plt.subplots(figsize=(6, 4))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax) 
```</details> ![../_images/0fbe0c20cd350821e64f3742aa7864ec729565572b136950042095881672fdb9.png](https://gitee.com/OpenDocCN/dsai-docs-zh/raw/master/docs/jax/img/f1e06280fcda633736c3159251434cfc.png)
从某种意义上说,将标量放在左边是一个奇怪的选择:标量类型可能包含任何宽度的值,但与给定类型的数组交互时,提升的结果将延续到数组类型。这样做的好处在于,当您对数组 `x` 执行像 `x + 2` 这样的操作时,`x` 的类型将传递到结果中,无论其宽度如何:
```py
for dtype in [np.int8, np.int16, np.int32, np.int64]:
  x = np.arange(10, dtype=dtype)
  assert (x + 2).dtype == dtype 

这种行为为标量值的 * 符号赋予了动机:* 符号类似于一个通配符,可以取任意所需的值。

这种语义的好处在于,您可以用清晰的 Python 代码轻松表达操作序列,而无需显式地将标量强制转换为适当的类型。想象一下,如果不是写成这样:

3 * (x + 1) ** 2 

您不得不写成这样:

np.int32(3) * (x + np.int32(1)) ** np.int32(2) 

尽管它很明确,数值代码会变得阅读或编写起来非常繁琐。使用上述标量提升语义,给定类型为 int32 的数组 x,第二个语句中的类型在第一个语句中是隐含的。

合并格点

请回想,我们开始讨论 Python 内部类型提升的格点图:int -> float -> complex。让我们将其重写为 i* -> f* -> c*,并允许 i* 吸收 u*(毕竟,在 Python 中没有无符号整数标量类型)。

将所有内容整合在一起,我们得到以下部分格点图,表示 Python 标量和 numpy 数组之间的类型提升:

显示代码单元格来源 隐藏代码单元格来源

#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax) 
```</details> ![../_images/796586be87180b0de3171d39763f2d33a80a641b72d82c00f0c0e352f754f201.png](https://gitee.com/OpenDocCN/dsai-docs-zh/raw/master/docs/jax/img/dd4a0cdc7416bcb8469bfa5424566191.png)
注意,这还不是一个真正的格:存在许多节点对,它们之间没有联接。然而,我们可以将其视为一个*部分*格,在这种格中,某些节点对没有定义的推广行为,而此部分格的定义部分确实正确描述了 NumPy 的数组推广行为(不考虑上述值依赖语义)。
这为我们提供了一个很好的框架,可以用来思考如何填补这些未定义的推广规则,方法是在这个图上添加连接。但是应该添加哪些连接呢?总体来说,我们希望任何额外的连接都满足几个属性:
1.  推广应满足交换和结合性质:换句话说,图应保持(部分)格的形式。
1.  推广不应允许丢弃数据的整个组成部分:例如,我们不应将`complex`推广为`float`,因为这会丢弃任何虚部。
1.  推广不应导致未处理的溢出。例如,最大可能的`uint32`是最大可能的`int32`的两倍,因此我们不应隐式地将`uint32`提升为`int32`。
1.  在可能的情况下,推广应避免精度损失。例如,一个`int64`值可能有 64 位的尾数,因此将`int64`提升为`float64`可能会导致精度损失。然而,最大可表示的 float64 大于最大可表示的 int64,因此在这种情况下仍满足标准 #3。
1.  在可能的情况下,二进制推广应避免导致比输入更宽的类型。这是为了确保 JAX 的隐式推广对加速器工作流友好,其中用户通常希望将类型限制为 32 位(或在某些情况下是 16 位)值。
格上的每一个新连接都为用户引入了一定程度的便利性(一组新的可以在没有显式转换的情况下相互作用的类型),但是如果以上任何标准被违反,这种便利性可能会变得代价高昂。发展一个完整的推广格涉及在便利性和成本之间达到平衡。
## 混合推广:浮点数和复数
让我们从可能是最简单的情况开始,即在浮点数和复数值之间的推广。
复数由一对浮点数组成,因此在它们之间存在一种自然的推广路径:将浮点数转换为复数,同时保持实部的宽度。在我们的部分格表示中,它看起来像这样:
<details class="hide above-input"><summary aria-label="Toggle hidden content">显示代码单元格源码 隐藏代码单元格源码</summary>
```py
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16'], 'u16': ['u32'], 'u32': ['u64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax) 
```</details> ![../_images/bf87909b2344aed80590d1c6d91585a02b25898ac217526cb49948d91205318f.png](https://gitee.com/OpenDocCN/dsai-docs-zh/raw/master/docs/jax/img/5d610fde93b793459425e06d94094f8f.png)
这恰好代表了 Numpy 在混合浮点/复数类型推广中使用的语义。
## 混合推广:有符号和无符号整数
接下来的情况,让我们考虑一些更困难的情况:有符号和无符号整数之间的提升。例如,当将`uint8`提升为有符号整数时,我们需要多少位?
乍一看,您可能会认为将`uint8`提升为`int8`是很自然的;但最大的`uint8`数字在`int8`中是不能表示的。因此,将无符号整数提升为比特数加倍的整数更有意义;这种提升行为可以通过将以下连接添加到提升格中来表示:
<details class="hide above-input"><summary aria-label="Toggle hidden content">显示代码单元格源代码 隐藏代码单元格源代码</summary>
```py
#@title
import networkx as nx
import matplotlib.pyplot as plt
lattice = {
  'i*': ['f*', 'u8', 'i8'], 'f*': ['c*', 'f16'], 'c*': ['c64'],
  'u8': ['u16', 'i16'], 'u16': ['u32', 'i32'], 'u32': ['u64', 'i64'],
  'i8': ['i16'], 'i16': ['i32'], 'i32': ['i64'],
  'f16': ['f32'], 'f32': ['f64', 'c64'], 'f64': ['c128'],
  'c64': ['c128']
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)
pos = {
  'i*': [-1.25, 0.5], 'f*': [-0.5, 2], 'c*': [0, 3],
  'u8': [0.5, 0], 'u16': [1.5, 0], 'u32': [2.5, 0], 'u64': [3.5, 0],
  'i8': [0, 1], 'i16': [1, 1], 'i32': [2, 1], 'i64': [3, 1],
  'f16': [0.5, 2], 'f32': [1.5, 2], 'f64': [2.5, 2],
  'c64': [2, 3], 'c128': [3, 3],
}
fig, ax = plt.subplots(figsize=(6, 5))
nx.draw(graph, with_labels=True, node_size=1500, node_color='lightgray', pos=pos, ax=ax) 
```</details> ![../_images/3be7e17889458ac823bb5dacf31525c0d96578c6854962f45dcc60ec987a30bd.png](https://gitee.com/OpenDocCN/dsai-docs-zh/raw/master/docs/jax/img/1a7d3b78b45858ca77d9810e77d053b9.png)
同样,这里添加的连接正是 NumPy 用于混合整数提升的提升语义实现。
### 如何处理`uint64`?
混合有符号/无符号整数提升的方法中缺少一种类型:`uint64`。按照上述模式,涉及`uint64`的混合整数操作的输出应该是`int128`,但这不是标准可用的数据类型。
NumPy 在这里的选择是提升为`float64`:
```py
(np.uint64(1) + np.int64(1)).dtype 
dtype('float64') 

然而,这可能是一个令人惊讶的约定:这是唯一一种整数类型提升不会产生整数的情况。目前,我们将保持uint64提升的未定义状态,并稍后再回到这个问题。


JAX 中文文档(十一)(5)https://developer.aliyun.com/article/1559784

相关文章
|
3月前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(4)
JAX 中文文档(十五)
26 3
|
3月前
|
存储 API 索引
JAX 中文文档(十五)(5)
JAX 中文文档(十五)
34 3
|
3月前
|
机器学习/深度学习 数据可视化 编译器
JAX 中文文档(十四)(5)
JAX 中文文档(十四)
34 2
|
3月前
JAX 中文文档(十一)(5)
JAX 中文文档(十一)
14 1
|
3月前
|
算法 API 开发工具
JAX 中文文档(十二)(5)
JAX 中文文档(十二)
40 1
|
3月前
|
存储 缓存 API
JAX 中文文档(十六)(1)
JAX 中文文档(十六)
30 1
|
3月前
|
安全 算法 API
JAX 中文文档(十一)(3)
JAX 中文文档(十一)
23 0
|
3月前
|
自然语言处理 Shell PyTorch
JAX 中文文档(十一)(2)
JAX 中文文档(十一)
22 0
|
3月前
|
机器学习/深度学习 分布式计算 程序员
JAX 中文文档(十一)(1)
JAX 中文文档(十一)
29 0
|
3月前
|
IDE API 开发工具
JAX 中文文档(十二)(2)
JAX 中文文档(十二)
24 0