JAX 中文文档(八)(1)

简介: JAX 中文文档(八)


原文:jax.readthedocs.io/en/latest/

自动微分手册

原文:jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html

alexbw@, mattjj@

JAX 拥有非常通用的自动微分系统。在这本手册中,我们将介绍许多巧妙的自动微分思想,您可以根据自己的工作进行选择。

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
key = random.key(0) 

梯度

grad开始

您可以使用grad对函数进行微分:

grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0)) 
0.070650816 

grad接受一个函数并返回一个函数。如果您有一个评估数学函数 ( f ) 的 Python 函数 f,那么 grad(f) 是一个评估数学函数 ( \nabla f ) 的 Python 函数。这意味着 grad(f)(x) 表示值 ( \nabla f(x) )。

由于grad操作函数,您可以将其应用于其自身的输出以多次进行微分:

print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0)) 
-0.13621868
0.25265405 

让我们看看如何在线性逻辑回归模型中使用grad计算梯度。首先,设置:

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)
# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)
# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])
# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))
# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ()) 

使用argnums参数的grad函数来相对于位置参数微分函数。

# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)
# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)
# But we can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)
# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad) 
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
b_grad -0.29227245
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
b_grad -0.29227245 

grad API 直接对应于 Spivak 经典著作Calculus on Manifolds(1965)中的优秀符号,也用于 Sussman 和 Wisdom 的Structure and Interpretation of Classical Mechanics(2015)及其Functional Differential Geometry(2013)。这两本书都是开放获取的。特别是参见Functional Differential Geometry的“序言”部分,以了解此符号的辩护。

当使用argnums参数时,如果f是一个用于计算数学函数 ( f ) 的 Python 函数,则 Python 表达式grad(f, i)用于评估 ( \partial_i f ) 的 Python 函数。

相对于嵌套列表、元组和字典进行微分

使用标准的 Python 容器进行微分是完全有效的,因此可以随意使用元组、列表和字典(以及任意嵌套)。

def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))
print(grad(loss2)({'W': W, 'b': b})) 
{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)} 

您可以注册您自己的容器类型以便不仅与grad一起工作,还可以与所有 JAX 转换(jitvmap等)一起工作。

使用value_and_grad评估函数及其梯度

另一个方便的函数是value_and_grad,可以高效地计算函数值及其梯度值:

from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b)) 
loss value 3.0519385
loss value 3.0519385 

与数值差分进行对比

导数的一个很好的特性是它们很容易用有限差分进行检查:

# Set a step size for finite differences calculations
eps = 1e-4
# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))
# Check W_grad with finite differences in a random direction
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec)) 
b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.2002716
W_dirderiv_autodiff -0.19909117 

JAX 提供了一个简单的便利函数,本质上执行相同的操作,但可以检查任何您喜欢的微分顺序:

from jax.test_util import check_grads
check_grads(loss, (W, b), order=2)  # check up to 2nd order derivatives 

使用 grad-of-grad 进行 Hessian 向量乘积

使用高阶 grad 可以构建一个 Hessian 向量乘积函数。 (稍后我们将编写一个更高效的实现,该实现混合了前向和反向模式,但这个实现将纯粹使用反向模式。)

在最小化平滑凸函数的截断牛顿共轭梯度算法或研究神经网络训练目标的曲率(例如1234)中,Hessian 向量乘积函数非常有用。

对于一个标量值函数 ( f : \mathbb{R}^n \to \mathbb{R} ),具有连续的二阶导数(因此 Hessian  矩阵是对称的),点 ( x \in \mathbb{R}^n ) 处的 Hessian 被写为 (\partial²  f(x))。然后,Hessian 向量乘积函数能够评估

(\qquad v \mapsto \partial² f(x) \cdot v)

对于任意 ( v \in \mathbb{R}^n )。

窍门在于不要实例化完整的 Hessian 矩阵:如果 ( n ) 很大,例如在神经网络的背景下可能是百万或十亿级别,那么可能无法存储。

幸运的是,grad 已经为我们提供了一种编写高效的 Hessian 向量乘积函数的方法。我们只需使用下面的身份证

(\qquad \partial² f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)),

其中 ( g(x) = \partial f(x) \cdot v ) 是一个新的标量值函数,它将 ( f ) 在 ( x ) 处的梯度与向量 ( v ) 点乘。请注意,我们只对向量值参数的标量值函数进行微分,这正是我们知道 grad 高效的地方。

在 JAX 代码中,我们可以直接写成这样:

def hvp(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x) 

这个例子表明,您可以自由使用词汇闭包,而 JAX 绝不会感到不安或困惑。

一旦我们看到如何计算密集的 Hessian 矩阵,我们将在几个单元格下检查此实现。我们还将编写一个更好的版本,该版本同时使用前向模式和反向模式。

使用 jacfwdjacrev 计算 Jacobians 和 Hessians

您可以使用 jacfwdjacrev 函数计算完整的 Jacobian 矩阵:

from jax import jacfwd, jacrev
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)
J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J) 
jacfwd result, with shape (4, 3)
[[ 0.05981758  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188288  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]
jacrev result, with shape (4, 3)
[[ 0.05981757  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]] 

这两个函数计算相同的值(直到机器数学),但它们在实现上有所不同:jacfwd 使用前向模式自动微分,对于“高”的 Jacobian 矩阵更有效,而 jacrev 使用反向模式,对于“宽”的 Jacobian 矩阵更有效。对于接近正方形的矩阵,jacfwd 可能比 jacrev 有优势。

您还可以在容器类型中使用 jacfwdjacrev

def predict_dict(params, inputs):
    return predict(params['W'], params['b'], inputs)
J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
    print("Jacobian from {} to logits is".format(k))
    print(v) 
Jacobian from W to logits is
[[ 0.05981757  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]
Jacobian from b to logits is
[0.11503381 0.04563541 0.23439017 0.00189771] 

关于前向模式和反向模式的更多细节,以及如何尽可能高效地实现 jacfwdjacrev,请继续阅读!

使用两个这些函数的复合给我们一种计算密集的 Hessian 矩阵的方法:

def hessian(f):
    return jacfwd(jacrev(f))
H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H) 
hessian, with shape (4, 3, 3)
[[[ 0.02285465  0.04922541  0.03384247]
  [ 0.04922541  0.10602397  0.07289147]
  [ 0.03384247  0.07289147  0.05011288]]
 [[-0.03195215  0.03921401 -0.00544639]
  [ 0.03921401 -0.04812629  0.00668421]
  [-0.00544639  0.00668421 -0.00092836]]
 [[-0.01583708 -0.00182736  0.03959271]
  [-0.00182736 -0.00021085  0.00456839]
  [ 0.03959271  0.00456839 -0.09898177]]
 [[-0.00103524  0.00348343 -0.00194457]
  [ 0.00348343 -0.01172127  0.0065432 ]
  [-0.00194457  0.0065432  -0.00365263]]] 

这种形状是合理的:如果我们从一个函数 (f : \mathbb{R}^n \to \mathbb{R}^m) 开始,那么在点 (x \in \mathbb{R}^n) 我们期望得到以下形状

  • (f(x) \in \mathbb{R}^m),在 (x) 处的 (f) 的值,
  • (\partial f(x) \in \mathbb{R}^{m \times n}),在 (x) 处的雅可比矩阵,
  • (\partial² f(x) \in \mathbb{R}^{m \times n \times n}),在 (x) 处的 Hessian 矩阵,

以及其他一些内容。

要实现 hessian,我们可以使用 jacfwd(jacrev(f))jacrev(jacfwd(f)) 或这两者的任何组合。但是前向超过反向通常是最有效的。这是因为在内部雅可比计算中,我们通常是在不同 iating  一个函数宽雅可比(也许像损失函数 (f : \mathbb{R}^n \to \mathbb{R})),而在外部雅可比计算中,我们是在不同  iating 具有方雅可比的函数(因为 (\nabla f : \mathbb{R}^n \to  \mathbb{R}^n)),这就是前向模式胜出的地方。

制造过程:两个基础的自动微分函数

雅可比-向量积(JVPs,也称为前向模式自动微分)

JAX 包括前向模式和反向模式自动微分的高效和通用实现。熟悉的 grad 函数建立在反向模式之上,但要解释两种模式的区别,以及每种模式何时有用,我们需要一些数学背景。

数学中的雅可比向量积

在数学上,给定一个函数 (f : \mathbb{R}^n \to \mathbb{R}^m),在输入点 (x \in  \mathbb{R}^n) 处评估的雅可比矩阵 (\partial f(x)),通常被视为一个 (\mathbb{R}^m \times  \mathbb{R}^n) 中的矩阵:

(\qquad \partial f(x) \in \mathbb{R}^{m \times n}).

但我们也可以将 (\partial f(x)) 看作是一个线性映射,它将 (f) 的定义域在点 (x) 的切空间(即另一个  (\mathbb{R}^n) 的副本)映射到 (f) 的值域在点 (f(x)) 的切空间(一个 (\mathbb{R}^m) 的副本):

(\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m).

此映射称为 (f) 在 (x) 处的推前映射。雅可比矩阵只是标准基中这个线性映射的矩阵。

如果我们不确定一个特定的输入点 (x),那么我们可以将函数 (\partial f) 视为首先接受一个输入点并返回该输入点处的雅可比线性映射:

(\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m)。

特别是,我们可以解开事物,这样给定输入点 (x \in \mathbb{R}^n) 和切向量 (v \in \mathbb{R}^n),我们得到一个输出切向量在 (\mathbb{R}^m) 中。我们称这种映射,从 ((x, v)) 对到输出切向量,为雅可比向量积,并将其写为

(\qquad (x, v) \mapsto \partial f(x) v)

JAX 代码中的雅可比向量积

回到 Python 代码中,JAX 的 jvp 函数模拟了这种转换。给定一个评估 (f) 的 Python 函数,JAX 的 jvp 是获取评估 ((x, v) \mapsto (f(x), \partial f(x) v)) 的 Python 函数的一种方法。

from jax import jvp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
key, subkey = random.split(key)
v = random.normal(subkey, W.shape)
# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,)) 

用类似 Haskell 的类型签名来说,我们可以写成

jvp  ::  (a  ->  b)  ->  a  ->  T  a  ->  (b,  T  b) 

在这里,我们使用 T a 来表示 a 的切线空间的类型。简言之,jvp 接受一个类型为 a -> b 的函数作为参数,一个类型为 a 的值,以及一个类型为 T a 的切线向量值。它返回一个由类型为 b 的值和类型为 T b 的输出切线向量组成的对。

jvp 转换后的函数的评估方式与原函数类似,但与每个类型为 a 的原始值配对时,它会沿着类型为 T a 的切线值进行推进。对于原始函数将应用的每个原始数值操作,jvp 转换后的函数会执行一个“JVP 规则”,该规则同时在这些原始值上评估原始数值,并应用其 JVP。

该评估策略对计算复杂度有一些直接影响:由于我们在进行评估时同时评估 JVP,因此我们不需要为以后存储任何内容,因此内存成本与计算深度无关。此外,jvp 转换后的函数的 FLOP 成本约为评估函数的成本的 3 倍(例如对于评估原始函数的一个单位工作,如 sin(x);一个单位用于线性化,如 cos(x);和一个单位用于将线性化函数应用于向量,如 cos_x * v)。换句话说,对于固定的原始点 (x),我们可以以大致相同的边际成本评估 (v \mapsto \partial f(x) \cdot v),如同评估 (f) 一样。

那么内存复杂度听起来非常有说服力!那为什么我们在机器学习中很少见到正向模式呢?

要回答这个问题,首先考虑如何使用 JVP 构建完整的 Jacobian 矩阵。如果我们将 JVP  应用于一个单位切线向量,它会显示出我们输入的非零条目对应的 Jacobian 矩阵的一列。因此,我们可以逐列地构建完整的 Jacobian  矩阵,获取每列的成本大约与一个函数评估相同。对于具有“高”Jacobian 的函数来说,这将是高效的,但对于“宽”Jacobian  来说则效率低下。

如果你在机器学习中进行基于梯度的优化,你可能想要最小化一个从 (\mathbb{R}^n) 中的参数到 (\mathbb{R})  中标量损失值的损失函数。这意味着这个函数的雅可比矩阵是一个非常宽的矩阵:(\partial f(x) \in \mathbb{R}^{1  \times n}),我们通常将其视为梯度向量 (\nabla f(x) \in  \mathbb{R}^n)。逐列构建这个矩阵,每次调用需要类似数量的浮点运算来评估原始函数,看起来确实效率低下!特别是对于训练神经网络,其中  (f) 是一个训练损失函数,而 (n) 可以是百万或十亿级别,这种方法根本不可扩展。

为了更好地处理这类函数,我们只需要使用反向模式。### 向量-雅可比积(VJPs,又称反向自动微分)

在前向模式中,我们得到了一个用于评估雅可比向量积的函数,然后我们可以使用它逐列构建雅可比矩阵;而反向模式则是一种获取用于评估向量-雅可比积(或等效地雅可比-转置向量积)的函数的方式,我们可以用它逐行构建雅可比矩阵。

数学中的 VJPs

再次考虑一个函数 (f : \mathbb{R}^n \to \mathbb{R}^m)。从我们对 JVP 的表示开始,对于 VJP 的表示非常简单:

(\qquad (x, v) \mapsto v \partial f(x)),

其中 (v) 是在 (x) 处 (f) 的余切空间的元素(同构于另一个 (\mathbb{R}^m) 的副本)。在严格时,我们应该将  (v) 视为一个线性映射 (v : \mathbb{R}^m \to \mathbb{R}),当我们写 (v \partial f(x))  时,我们意味着函数复合 (v \circ \partial f(x)),其中类型之间的对应关系是因为 (\partial f(x) :  \mathbb{R}^n \to \mathbb{R}^m)。但在通常情况下,我们可以将 (v) 与 (\mathbb{R}^m)  中的一个向量等同看待,并几乎可以互换使用,就像有时我们可以在“列向量”和“行向量”之间轻松切换而不加过多评论一样。

有了这个认识,我们可以将 VJP 的线性部分看作是 JVP 线性部分的转置(或共轭伴随):

(\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v).

对于给定点 (x),我们可以将签名写为

(\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n).

对应的余切空间映射通常称为[拉回](https://en.wikipedia.org/wiki/Pullback_(differential_geometry))  (f) 在 (x) 处的。对我们而言,关键在于它从类似 (f) 输出的东西到类似 (f) 输入的东西,就像我们从一个转置线性函数所期望的那样。

JAX 代码中的 VJPs

从数学切换回 Python,JAX 函数 vjp 可以接受一个用于评估 (f) 的 Python 函数,并给我们返回一个用于评估 VJP ((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))) 的 Python 函数。

from jax import vjp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
y, vjp_fun = vjp(f, W)
key, subkey = random.split(key)
u = random.normal(subkey, y.shape)
# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u) 

类似 Haskell 类型签名的形式来说,我们可以写成

vjp  ::  (a  ->  b)  ->  a  ->  (b,  CT  b  ->  CT  a) 

在这里,我们使用 CT a 表示 a 的余切空间的类型。换句话说,vjp 接受类型为 a -> b 的函数和类型为 a 的点作为参数,并返回一个由类型为 b 的值和类型为 CT b -> CT a 的线性映射组成的对。

这很棒,因为它让我们一次一行地构建雅可比矩阵,并且评估 ((x, v) \mapsto (f(x), v^\mathsf{T}  \partial f(x))) 的 FLOP 成本仅约为评估 (f) 的三倍。特别是,如果我们想要函数 (f : \mathbb{R}^n  \to \mathbb{R}) 的梯度,我们可以一次性完成。这就是 grad 对基于梯度的优化非常高效的原因,即使是对于数百万或数十亿个参数的神经网络训练损失函数这样的目标。

这里有一个成本:虽然 FLOP 友好,但内存随计算深度而增加。而且,该实现在传统上比前向模式更为复杂,但 JAX 对此有一些窍门(这是未来笔记本的故事!)。

关于反向模式的工作原理,可以查看2017 年深度学习暑期学校的教程视频

使用 VJPs 的矢量值梯度

如果你对使用矢量值梯度(如 tf.gradients)感兴趣:

from jax import vjp
def vgrad(f, x):
  y, vjp_fn = vjp(f, x)
  return vjp_fn(jnp.ones(y.shape))[0]
print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2)))) 
[[6\. 6.]
 [6\. 6.]] 


JAX 中文文档(八)(2)https://developer.aliyun.com/article/1559663

相关文章
|
3月前
|
并行计算 API 异构计算
JAX 中文文档(六)(2)
JAX 中文文档(六)
31 1
|
3月前
|
并行计算 API C++
JAX 中文文档(九)(4)
JAX 中文文档(九)
33 1
|
3月前
|
机器学习/深度学习 PyTorch API
JAX 中文文档(六)(1)
JAX 中文文档(六)
26 0
JAX 中文文档(六)(1)
|
3月前
|
存储 移动开发 Python
JAX 中文文档(八)(2)
JAX 中文文档(八)
23 0
|
3月前
|
存储 安全 API
JAX 中文文档(十)(2)
JAX 中文文档(十)
33 0
|
3月前
|
机器学习/深度学习
JAX 中文文档(六)(5)
JAX 中文文档(六)
24 0
|
3月前
|
并行计算 编译器
JAX 中文文档(六)(4)
JAX 中文文档(六)
19 0
|
3月前
|
缓存 Serverless API
JAX 中文文档(十)(4)
JAX 中文文档(十)
25 0
|
3月前
|
存储 机器学习/深度学习 TensorFlow
JAX 中文文档(七)(5)
JAX 中文文档(七)
23 0
|
3月前
|
API 索引 Python
JAX 中文文档(三)(4)
JAX 中文文档(三)
18 0