JAX 中文文档(八)(2)

简介: JAX 中文文档(八)

JAX 中文文档(八)(1)https://developer.aliyun.com/article/1559662


使用前向和反向模式的黑塞矢量积

在前面的部分中,我们仅使用反向模式实现了一个黑塞-矢量积函数(假设具有连续二阶导数):

def hvp(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x) 

这是高效的,但我们甚至可以更好地节省一些内存,通过使用前向模式和反向模式。

从数学上讲,给定一个要区分的函数 (f : \mathbb{R}^n \to \mathbb{R}),要线性化函数的一个点 (x \in \mathbb{R}^n),以及一个向量 (v \in \mathbb{R}^n),我们想要的黑塞-矢量积函数是

((x, v) \mapsto \partial² f(x) v)

考虑助手函数 (g : \mathbb{R}^n \to \mathbb{R}^n) 定义为 (f) 的导数(或梯度),即 (g(x) = \partial f(x))。我们所需的只是它的 JVP,因为这将给我们

((x, v) \mapsto \partial g(x) v = \partial² f(x) v).

我们几乎可以直接将其转换为代码:

from jax import jvp, grad
# forward-over-reverse
def hvp(f, primals, tangents):
  return jvp(grad(f), primals, tangents)[1] 

更好的是,由于我们不需要直接调用 jnp.dot,这个 hvp 函数可以处理任何形状的数组以及任意的容器类型(如嵌套列表/字典/元组中存储的向量),甚至与jax.numpy 没有任何依赖。

这是如何使用它的示例:

def f(X):
  return jnp.sum(jnp.tanh(X)**2)
key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))
ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)
print(jnp.allclose(ans1, ans2, 1e-4, 1e-4)) 
True 

另一种你可能考虑写这个的方法是使用反向-前向模式:

# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
  g = lambda primals: jvp(f, primals, tangents)[1]
  return grad(g)(primals) 

不过,这不是很好,因为前向模式的开销比反向模式小,由于外部区分算子要区分比内部更大的计算,将前向模式保持在外部是最好的:

# reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
  x, = primals
  v, = tangents
  return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))
print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2) 
Forward over reverse
4.74 ms ± 157 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
9.46 ms ± 5.05 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
14.3 ms ± 7.71 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
57.7 ms ± 1.32 ms per loop (mean ± std. dev. of 3 runs, 10 loops each) 

组成 VJP、JVP 和 vmap

雅可比-矩阵和矩阵-雅可比乘积

现在我们有jvpvjp变换,它们为我们提供了推送或拉回单个向量的函数,我们可以使用 JAX 的vmap 变换一次推送和拉回整个基。特别是,我们可以用它来快速编写矩阵-雅可比和雅可比-矩阵乘积。

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    return jnp.vstack([vjp_fun(mi) for mi in M])
# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    outs, = vmap(vjp_fun)(M)
    return outs
key = random.key(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)
loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)
print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)
assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical' 
Non-vmapped Matrix-Jacobian product
168 ms ± 260 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped Matrix-Jacobian product
6.39 ms ± 49.3 μs per loop (mean ± std. dev. of 3 runs, 10 loops each) 
/tmp/ipykernel_1379/3769736790.py:8: DeprecationWarning: vstack requires ndarray or scalar arguments, got <class 'tuple'> at position 0\. In a future JAX release this will be an error.
  return jnp.vstack([vjp_fun(mi) for mi in M]) 
def loop_jmp(f, W, M):
    # jvp immediately returns the primal and tangent values as a tuple,
    # so we'll compute and select the tangents in a list comprehension
    return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])
def vmap_jmp(f, W, M):
    _jvp = lambda s: jvp(f, (W,), (s,))[1]
    return vmap(_jvp)(M)
num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)
loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)
assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical' 
Non-vmapped Jacobian-Matrix product
290 ms ± 437 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped Jacobian-Matrix product
3.29 ms ± 22.5 μs per loop (mean ± std. dev. of 3 runs, 10 loops each) 

jacfwdjacrev的实现

现在我们已经看到了快速的雅可比-矩阵和矩阵-雅可比乘积,写出jacfwdjacrev并不难。我们只需使用相同的技术一次推送或拉回整个标准基(等同于单位矩阵)。

from jax import jacrev as builtin_jacrev
def our_jacrev(f):
    def jacfun(x):
        y, vjp_fun = vjp(f, x)
        # Use vmap to do a matrix-Jacobian product.
        # Here, the matrix is the Euclidean basis, so we get all
        # entries in the Jacobian at once. 
        J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
        return J
    return jacfun
assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!' 
from jax import jacfwd as builtin_jacfwd
def our_jacfwd(f):
    def jacfun(x):
        _jvp = lambda s: jvp(f, (x,), (s,))[1]
        Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
        return jnp.transpose(Jt)
    return jacfun
assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!' 

有趣的是,Autograd做不到这一点。我们在 Autograd 中反向模式jacobian实现必须逐个向量地拉回,使用外层循环map。逐个向量地通过计算远不及使用vmap一次将所有内容批处理高效。

另一件 Autograd 做不到的事情是jit。有趣的是,无论您在要进行微分的函数中使用多少 Python 动态性,我们总是可以在计算的线性部分上使用jit。例如:

def f(x):
    try:
        if x < 3:
            return 2 * x ** 3
        else:
            raise ValueError
    except ValueError:
        return jnp.pi * x
y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.)) 
(Array(3.1415927, dtype=float32, weak_type=True),) 

复数和微分

JAX 在复数和微分方面表现出色。为了支持全纯和非全纯微分,理解 JVP 和 VJP 很有帮助。

考虑一个复到复的函数 (f: \mathbb{C} \to \mathbb{C}) 并将其与相应的函数 (g: \mathbb{R}² \to \mathbb{R}²) 对应起来,

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return u(x, y) + v(x, y) * 1j
def g(x, y):
  return (u(x, y), v(x, y)) 

也就是说,我们分解了 (f(z) = u(x, y) + v(x, y) i) 其中 (z = x + y i),并将 (\mathbb{C}) 与 (\mathbb{R}²) 对应起来得到了 (g)。

由于 (g) 只涉及实数输入和输出,我们已经知道如何为它编写雅可比-向量积,例如给定切向量 ((c, d) \in \mathbb{R}²),

([0u(x,y)1u(x,y) 0v(x,y)1v(x,y)] [c d]).

要获得应用于切向量 (c + di \in \mathbb{C}) 的原始函数 (f) 的 JVP,我们只需使用相同的定义,并将结果标识为另一个复数,

(\partial f(x + y i)(c + d i) = [1i]   [0u(x,y)1u(x,y) 0v(x,y)1v(x,y)] [c d]).

这就是我们对复到复函数 (f) 的 JVP 的定义!注意,无论 (f) 是否全纯,JVP 都是明确的。

这里是一个检查:

def check(seed):
  key = random.key(seed)
  # random coeffs for u and v
  key, subkey = random.split(key)
  a, b, c, d = random.uniform(subkey, (4,))
  def fun(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j
  def u(x, y):
    return a * x + b * y
  def v(x, y):
    return c * x + d * y
  # primal point
  key, subkey = random.split(key)
  x, y = random.uniform(subkey, (2,))
  z = x + y * 1j
  # tangent vector
  key, subkey = random.split(key)
  c, d = random.uniform(subkey, (2,))
  z_dot = c + d * 1j
  # check jvp
  _, ans = jvp(fun, (z,), (z_dot,))
  expected = (grad(u, 0)(x, y) * c +
              grad(u, 1)(x, y) * d +
              grad(v, 0)(x, y) * c * 1j+
              grad(v, 1)(x, y) * d * 1j)
  print(jnp.allclose(ans, expected)) 
check(0)
check(1)
check(2) 
True
True
True 

那么 VJP 呢?我们做了类似的事情:对于余切向量 (c + di \in \mathbb{C}),我们将 (f) 的 VJP 定义为

((c + di)^* ; \partial f(x + y i) = [cd]   [0u(x,y)1u(x,y) 0v(x,y)1v(x,y)] [1 i]).

为什么要有负号?这些只是为了处理复共轭,以及我们正在处理余切向量的事实。

这里是 VJP 规则的检查:

def check(seed):
  key = random.key(seed)
  # random coeffs for u and v
  key, subkey = random.split(key)
  a, b, c, d = random.uniform(subkey, (4,))
  def fun(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j
  def u(x, y):
    return a * x + b * y
  def v(x, y):
    return c * x + d * y
  # primal point
  key, subkey = random.split(key)
  x, y = random.uniform(subkey, (2,))
  z = x + y * 1j
  # cotangent vector
  key, subkey = random.split(key)
  c, d = random.uniform(subkey, (2,))
  z_bar = jnp.array(c + d * 1j)  # for dtype control
  # check vjp
  _, fun_vjp = vjp(fun, z)
  ans, = fun_vjp(z_bar)
  expected = (grad(u, 0)(x, y) * c +
              grad(v, 0)(x, y) * (-d) +
              grad(u, 1)(x, y) * c * (-1j) +
              grad(v, 1)(x, y) * (-d) * (-1j))
  assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5) 
check(0)
check(1)
check(2) 

方便的包装器如gradjacfwdjacrev有什么作用?

对于(\mathbb{R} \to \mathbb{R})函数,回想我们定义grad(f)(x)vjp(f, x)1,这是因为将 VJP 应用于1.0值会显示梯度(即雅可比矩阵或导数)。对于(\mathbb{C} \to \mathbb{R})函数,我们可以做同样的事情:我们仍然可以使用1.0作为余切向量,而我们得到的只是总结完整雅可比矩阵的一个复数结果:

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return x**2 + y**2
z = 3. + 4j
grad(f)(z) 
Array(6.-8.j, dtype=complex64) 

对于一般的(\mathbb{C} \to \mathbb{C})函数,雅可比矩阵有 4 个实值自由度(如上面的 2x2  雅可比矩阵),因此我们不能希望在一个复数中表示所有这些自由度。但对于全纯函数,我们可以!全纯函数恰好是一个(\mathbb{C} \to  \mathbb{C})函数,其导数可以表示为一个单一的复数。(柯西-黎曼方程确保上述 2x2 雅可比矩阵在复平面内的作用具有复数乘法下的比例和旋转矩阵的特殊形式。)我们可以使用一个vjp调用并带有1.0的余切向量来揭示那一个复数。

因为这仅适用于全纯函数,为了使用这个技巧,我们需要向 JAX 保证我们的函数是全纯的;否则,在复数输出函数上使用grad时,JAX 会引发错误:

def f(z):
  return jnp.sin(z)
z = 3. + 4j
grad(f, holomorphic=True)(z) 
Array(-27.034945-3.8511531j, dtype=complex64, weak_type=True) 

holomorphic=True的承诺仅仅是在输出是复数值时禁用错误。当函数不是全纯时,我们仍然可以写holomorphic=True,但得到的答案将不表示完整的雅可比矩阵。相反,它将是在我们只丢弃输出的虚部的函数的雅可比矩阵。

def f(z):
  return jnp.conjugate(z)
z = 3. + 4j
grad(f, holomorphic=True)(z)  # f is not actually holomorphic! 
Array(1.-0.j, dtype=complex64, weak_type=True) 

在这里grad的工作有一些有用的结论:

  1. 我们可以在全纯的(\mathbb{C} \to \mathbb{C})函数上使用grad
  2. 我们可以使用grad来优化(\mathbb{C} \to \mathbb{R})函数,例如复参数x的实值损失函数,通过朝着grad(f)(x)的共轭方向迈出步伐。
  3. 如果我们有一个(\mathbb{R} \to \mathbb{R})的函数,它恰好在内部使用一些复数运算(其中一些必须是非全纯的,例如在卷积中使用的 FFT),那么grad仍然有效,并且我们得到与仅使用实数值的实现相同的结果。

在任何情况下,JVPs 和 VJPs 都是明确的。如果我们想计算非全纯函数(\mathbb{C} \to \mathbb{C})的完整 Jacobian 矩阵,我们可以用 JVPs 或 VJPs 来做到!

你应该期望复数在 JAX 中的任何地方都能正常工作。这里是通过复杂矩阵的 Cholesky 分解进行微分:

A = jnp.array([[5.,    2.+3j,    5j],
              [2.-3j,   7.,  1.+7j],
              [-5j,  1.-7j,    12.]])
def f(X):
    L = jnp.linalg.cholesky(X)
    return jnp.sum((L - jnp.sin(L))**2)
grad(f, holomorphic=True)(A) 
Array([[-0.7534186  +0.j       , -3.0509028 -10.940544j ,
         5.9896846  +3.5423026j],
       [-3.0509028 +10.940544j , -8.904491   +0.j       ,
        -5.1351523  -6.559373j ],
       [ 5.9896846  -3.5423026j, -5.1351523  +6.559373j ,
         0.01320427 +0.j       ]], dtype=complex64) 

更高级的自动微分

在这本笔记本中,我们通过一些简单的,然后逐渐复杂的应用中,使用 JAX 中的自动微分。我们希望现在您感觉在 JAX 中进行导数运算既简单又强大。

还有很多其他自动微分的技巧和功能。我们没有涵盖的主题,但希望在“高级自动微分手册”中进行涵盖:

  • 高斯-牛顿向量乘积,一次线性化
  • 自定义的 VJPs 和 JVPs
  • 在固定点处高效地求导
  • 使用随机的 Hessian-vector products 来估计 Hessian 的迹。
  • 仅使用反向模式自动微分的前向模式自动微分。
  • 对自定义数据类型进行导数计算。
  • 检查点(二项式检查点用于高效的反向模式,而不是模型快照)。
  • 优化 VJPs 通过 Jacobian 预积累。


JAX 中文文档(八)(3)https://developer.aliyun.com/article/1559664

相关文章
|
3月前
|
并行计算 API 异构计算
JAX 中文文档(六)(2)
JAX 中文文档(六)
31 1
|
3月前
|
测试技术 API Python
JAX 中文文档(八)(4)
JAX 中文文档(八)
27 0
|
3月前
|
存储 Python
JAX 中文文档(十)(3)
JAX 中文文档(十)
19 0
|
3月前
|
机器学习/深度学习 API 索引
JAX 中文文档(二)(2)
JAX 中文文档(二)
25 0
|
3月前
|
缓存 Serverless API
JAX 中文文档(十)(4)
JAX 中文文档(十)
25 0
|
3月前
|
API Python
JAX 中文文档(八)(3)
JAX 中文文档(八)
26 0
|
3月前
|
机器学习/深度学习 缓存 编译器
JAX 中文文档(二)(1)
JAX 中文文档(二)
45 0
|
3月前
|
机器学习/深度学习 存储 并行计算
JAX 中文文档(七)(3)
JAX 中文文档(七)
29 0
|
3月前
|
存储 缓存 测试技术
JAX 中文文档(三)(5)
JAX 中文文档(三)
38 0
|
3月前
|
编译器 异构计算 索引
JAX 中文文档(五)(4)
JAX 中文文档(五)
53 0