JAX 中文文档(八)(3)

简介: JAX 中文文档(八)

JAX 中文文档(八)(2)https://developer.aliyun.com/article/1559663


JAX 可转换的 Python 函数的自定义导数规则

原文:jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html

[外链图片转存中…(img-Tun1SV4O-1718950304061)]

mattjj@ Mar 19 2020, last updated Oct 14 2020

JAX 中定义微分规则的两种方式:

  1. 使用 jax.custom_jvpjax.custom_vjp 来为已经可转换为 JAX 的 Python 函数定义自定义微分规则;以及
  2. 定义新的 core.Primitive 实例及其所有转换规则,例如调用来自其他系统(如求解器、模拟器或一般数值计算系统)的函数。

本笔记本讨论的是 #1. 要了解关于 #2 的信息,请参阅关于添加原语的笔记本

关于 JAX 自动微分 API 的介绍,请参阅自动微分手册。本笔记本假定读者已对jax.jvpjax.grad,以及 JVPs 和 VJPs 的数学含义有一定了解。

TL;DR

使用 jax.custom_jvp 进行自定义 JVPs

import jax.numpy as jnp
from jax import custom_jvp
@custom_jvp
def f(x, y):
  return jnp.sin(x) * y
@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out 
from jax import jvp, grad
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.)) 
2.7278922
2.7278922
-1.2484405
-1.2484405 
# Equivalent alternative using the defjvps convenience wrapper
@custom_jvp
def f(x, y):
  return jnp.sin(x) * y
f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
          lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot) 
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.)) 
2.7278922
2.7278922
-1.2484405
-1.2484405 

使用 jax.custom_vjp 进行自定义 VJPs

from jax import custom_vjp
@custom_vjp
def f(x, y):
  return jnp.sin(x) * y
def f_fwd(x, y):
# Returns primal output and residuals to be used in backward pass by f_bwd.
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)
def f_bwd(res, g):
  cos_x, sin_x, y = res # Gets residuals computed in f_fwd
  return (cos_x * g * y, sin_x * g)
f.defvjp(f_fwd, f_bwd) 
print(grad(f)(2., 3.)) 
-1.2484405 

示例问题

要了解 jax.custom_jvpjax.custom_vjp 所解决的问题,我们可以看几个例子。有关 jax.custom_jvpjax.custom_vjp API 的更详细介绍在下一节中。

数值稳定性

jax.custom_jvp 的一个应用是提高微分的数值稳定性。

假设我们想编写一个名为 log1pexp 的函数,用于计算 (x \mapsto \log ( 1 + e^x ))。我们可以使用 jax.numpy 来写:

import jax.numpy as jnp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))
log1pexp(3.) 
Array(3.0485873, dtype=float32, weak_type=True) 

因为它是用 jax.numpy 编写的,所以它是 JAX 可转换的:

from jax import jit, grad, vmap
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) 
3.0485873
0.95257413
[0.5       0.7310586 0.8807971] 

但这里存在一个数值稳定性问题:

print(grad(log1pexp)(100.)) 
nan 

那似乎不对!毕竟,(x \mapsto \log (1 + e^x)) 的导数是 (x \mapsto \frac{e^x}{1 + e^x}),因此对于大的 (x) 值,我们期望值约为 1。

通过查看梯度计算的 jaxpr,我们可以更深入地了解发生了什么:

from jax import make_jaxpr
make_jaxpr(grad(log1pexp))(100.) 
{ lambda ; a:f32[]. let
    b:f32[] = exp a
    c:f32[] = add 1.0 b
    _:f32[] = log c
    d:f32[] = div 1.0 c
    e:f32[] = mul d b
  in (e,) } 

通过分析 jaxpr 如何评估,我们可以看到最后一行涉及的值相乘会导致浮点数计算四舍五入为 0 和 (\infty),这从未是一个好主意。也就是说,我们实际上在评估大数值的情况下,计算的是 lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x),这实际上会变成 0. * jnp.inf

而不是生成这样大和小的值,希望浮点数能够提供的取消,我们宁愿将导数函数表达为一个更稳定的数值程序。特别地,我们可以编写一个程序,更接近地评估相等的数学表达式 (1 - \frac{1}{1 + e^x}),看不到取消。

这个问题很有趣,因为即使我们的log1pexp的定义已经可以进行 JAX 微分(并且可以使用jitvmap等转换),我们对应用标准自动微分规则到组成log1pexp并组合结果的结果并不满意。相反,我们想要指定整个函数log1pexp如何作为一个单位进行微分,从而更好地安排这些指数。

这是关于 Python 函数的自定义导数规则的一个应用,这些函数已经可以使用 JAX 进行转换:指定如何对复合函数进行微分,同时仍然使用其原始的 Python 定义进行其他转换(如jitvmap等)。

这里是使用jax.custom_jvp的解决方案:

from jax import custom_jvp
@custom_jvp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))
@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = log1pexp(x)
  ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot
  return ans, ans_dot 
print(grad(log1pexp)(100.)) 
1.0 
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) 
3.0485873
0.95257413
[0.5       0.7310586 0.8807971] 

这里是一个defjvps方便包装,来表达同样的事情:

@custom_jvp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))
log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t) 
print(grad(log1pexp)(100.))
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) 
1.0
3.0485873
0.95257413
[0.5       0.7310586 0.8807971] 

强制执行微分约定

一个相关的应用是强制执行微分约定,也许在边界处。

考虑函数 (f : \mathbb{R}+ \to \mathbb{R}+),其中 (f(x) = \frac{x}{1 + \sqrt{x}}),其中我们取 (\mathbb{R}_+ = [0, \infty))。我们可以像这样实现 (f) 的程序:

def f(x):
  return x / (1 + jnp.sqrt(x)) 

作为在(\mathbb{R})上的数学函数(完整的实数线),(f) 在零点是不可微的(因为从左侧定义导数的极限不存在)。相应地,自动微分产生一个nan值:

print(grad(f)(0.)) 
nan 

但是数学上,如果我们将 (f) 视为 (\mathbb{R}_+) 上的函数,则它在 0 处是可微的 [Rudin 的《数学分析原理》定义  5.1,或 Tao 的《分析 I》第 3 版定义 10.1.1 和例子  10.1.6]。或者,我们可能会说,作为一个惯例,我们希望考虑从右边的方向导数。因此,对于 Python 函数grad(f)0.0处返回 1.0 是有意义的值。默认情况下,JAX 对微分的机制假设所有函数在(\mathbb{R})上定义,因此这里并不会产生1.0

我们可以使用自定义的 JVP 规则!特别地,我们可以定义 JVP 规则,关于导数函数 (x \mapsto \frac{\sqrt{x} + 2}{2(\sqrt{x} + 1)²}) 在 (\mathbb{R}_+) 上,

@custom_jvp
def f(x):
  return x / (1 + jnp.sqrt(x))
@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = f(x)
  ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot
  return ans, ans_dot 
print(grad(f)(0.)) 
1.0 

这里是方便包装版本:

@custom_jvp
def f(x):
  return x / (1 + jnp.sqrt(x))
f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t) 
print(grad(f)(0.)) 
1.0 

梯度剪裁

虽然在某些情况下,我们想要表达一个数学微分计算,在其他情况下,我们甚至可能想要远离数学,来调整自动微分的计算。一个典型的例子是反向模式梯度剪裁。

对于梯度剪裁,我们可以使用jnp.clip和一个jax.custom_vjp仅逆模式规则:

from functools import partial
from jax import custom_vjp
@custom_vjp
def clip_gradient(lo, hi, x):
  return x  # identity function
def clip_gradient_fwd(lo, hi, x):
  return x, (lo, hi)  # save bounds as residuals
def clip_gradient_bwd(res, g):
  lo, hi = res
  return (None, None, jnp.clip(g, lo, hi))  # use None to indicate zero cotangents for lo and hi
clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd) 
import matplotlib.pyplot as plt
from jax import vmap
t = jnp.linspace(0, 10, 1000)
plt.plot(jnp.sin(t))
plt.plot(vmap(grad(jnp.sin))(t)) 
[<matplotlib.lines.Line2D at 0x7f43dfc210f0>] 

def clip_sin(x):
  x = clip_gradient(-0.75, 0.75, x)
  return jnp.sin(x)
plt.plot(clip_sin(t))
plt.plot(vmap(grad(clip_sin))(t)) 
[<matplotlib.lines.Line2D at 0x7f43ddb15fc0>] 

Python 调试

另一个应用,是受开发工作流程而非数值驱动的动机,是在反向模式自动微分的后向传递中设置pdb调试器跟踪。

在尝试追踪nan运行时错误的来源,或者仅仔细检查传播的余切(梯度)值时,可以在反向传递中的特定点插入调试器非常有用。您可以使用jax.custom_vjp来实现这一点。

我们将在下一节中推迟一个示例。


JAX 中文文档(八)(4)https://developer.aliyun.com/article/1559665

相关文章
|
3月前
|
机器学习/深度学习 存储 移动开发
JAX 中文文档(八)(1)
JAX 中文文档(八)
22 1
|
3月前
|
机器学习/深度学习 PyTorch API
JAX 中文文档(六)(1)
JAX 中文文档(六)
26 0
JAX 中文文档(六)(1)
|
3月前
|
并行计算 测试技术 异构计算
JAX 中文文档(一)(5)
JAX 中文文档(一)
59 0
|
3月前
|
机器学习/深度学习 索引 Python
JAX 中文文档(四)(1)
JAX 中文文档(四)
29 0
|
3月前
|
并行计算 Linux 异构计算
JAX 中文文档(一)(1)
JAX 中文文档(一)
83 0
|
3月前
|
编译器 API C++
JAX 中文文档(三)(3)
JAX 中文文档(三)
22 0
|
3月前
|
Python
JAX 中文文档(十)(5)
JAX 中文文档(十)
23 0
|
3月前
|
机器学习/深度学习 并行计算 安全
JAX 中文文档(七)(1)
JAX 中文文档(七)
34 0
|
3月前
|
API 索引 Python
JAX 中文文档(三)(4)
JAX 中文文档(三)
18 0
|
3月前
|
机器学习/深度学习
JAX 中文文档(六)(5)
JAX 中文文档(六)
26 0