JAX 中文文档(八)(2)https://developer.aliyun.com/article/1559663
JAX 可转换的 Python 函数的自定义导数规则
原文:
jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
[外链图片转存中…(img-Tun1SV4O-1718950304061)]
mattjj@ Mar 19 2020, last updated Oct 14 2020
JAX 中定义微分规则的两种方式:
- 使用
jax.custom_jvp
和jax.custom_vjp
来为已经可转换为 JAX 的 Python 函数定义自定义微分规则;以及 - 定义新的
core.Primitive
实例及其所有转换规则,例如调用来自其他系统(如求解器、模拟器或一般数值计算系统)的函数。
本笔记本讨论的是 #1. 要了解关于 #2 的信息,请参阅关于添加原语的笔记本。
关于 JAX 自动微分 API 的介绍,请参阅自动微分手册。本笔记本假定读者已对jax.jvp和jax.grad,以及 JVPs 和 VJPs 的数学含义有一定了解。
TL;DR
使用 jax.custom_jvp
进行自定义 JVPs
import jax.numpy as jnp from jax import custom_jvp @custom_jvp def f(x, y): return jnp.sin(x) * y @f.defjvp def f_jvp(primals, tangents): x, y = primals x_dot, y_dot = tangents primal_out = f(x, y) tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot return primal_out, tangent_out
from jax import jvp, grad print(f(2., 3.)) y, y_dot = jvp(f, (2., 3.), (1., 0.)) print(y) print(y_dot) print(grad(f)(2., 3.))
2.7278922 2.7278922 -1.2484405 -1.2484405
# Equivalent alternative using the defjvps convenience wrapper @custom_jvp def f(x, y): return jnp.sin(x) * y f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y, lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)
print(f(2., 3.)) y, y_dot = jvp(f, (2., 3.), (1., 0.)) print(y) print(y_dot) print(grad(f)(2., 3.))
2.7278922 2.7278922 -1.2484405 -1.2484405
使用 jax.custom_vjp
进行自定义 VJPs
from jax import custom_vjp @custom_vjp def f(x, y): return jnp.sin(x) * y def f_fwd(x, y): # Returns primal output and residuals to be used in backward pass by f_bwd. return f(x, y), (jnp.cos(x), jnp.sin(x), y) def f_bwd(res, g): cos_x, sin_x, y = res # Gets residuals computed in f_fwd return (cos_x * g * y, sin_x * g) f.defvjp(f_fwd, f_bwd)
print(grad(f)(2., 3.))
-1.2484405
示例问题
要了解 jax.custom_jvp
和 jax.custom_vjp
所解决的问题,我们可以看几个例子。有关 jax.custom_jvp
和 jax.custom_vjp
API 的更详细介绍在下一节中。
数值稳定性
jax.custom_jvp
的一个应用是提高微分的数值稳定性。
假设我们想编写一个名为 log1pexp
的函数,用于计算 (x \mapsto \log ( 1 + e^x ))。我们可以使用 jax.numpy
来写:
import jax.numpy as jnp def log1pexp(x): return jnp.log(1. + jnp.exp(x)) log1pexp(3.)
Array(3.0485873, dtype=float32, weak_type=True)
因为它是用 jax.numpy
编写的,所以它是 JAX 可转换的:
from jax import jit, grad, vmap print(jit(log1pexp)(3.)) print(jit(grad(log1pexp))(3.)) print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
3.0485873 0.95257413 [0.5 0.7310586 0.8807971]
但这里存在一个数值稳定性问题:
print(grad(log1pexp)(100.))
nan
那似乎不对!毕竟,(x \mapsto \log (1 + e^x)) 的导数是 (x \mapsto \frac{e^x}{1 + e^x}),因此对于大的 (x) 值,我们期望值约为 1。
通过查看梯度计算的 jaxpr,我们可以更深入地了解发生了什么:
from jax import make_jaxpr make_jaxpr(grad(log1pexp))(100.)
{ lambda ; a:f32[]. let b:f32[] = exp a c:f32[] = add 1.0 b _:f32[] = log c d:f32[] = div 1.0 c e:f32[] = mul d b in (e,) }
通过分析 jaxpr 如何评估,我们可以看到最后一行涉及的值相乘会导致浮点数计算四舍五入为 0 和 (\infty),这从未是一个好主意。也就是说,我们实际上在评估大数值的情况下,计算的是 lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x)
,这实际上会变成 0. * jnp.inf
。
而不是生成这样大和小的值,希望浮点数能够提供的取消,我们宁愿将导数函数表达为一个更稳定的数值程序。特别地,我们可以编写一个程序,更接近地评估相等的数学表达式 (1 - \frac{1}{1 + e^x}),看不到取消。
这个问题很有趣,因为即使我们的log1pexp
的定义已经可以进行 JAX 微分(并且可以使用jit
、vmap
等转换),我们对应用标准自动微分规则到组成log1pexp
并组合结果的结果并不满意。相反,我们想要指定整个函数log1pexp
如何作为一个单位进行微分,从而更好地安排这些指数。
这是关于 Python 函数的自定义导数规则的一个应用,这些函数已经可以使用 JAX 进行转换:指定如何对复合函数进行微分,同时仍然使用其原始的 Python 定义进行其他转换(如jit
、vmap
等)。
这里是使用jax.custom_jvp
的解决方案:
from jax import custom_jvp @custom_jvp def log1pexp(x): return jnp.log(1. + jnp.exp(x)) @log1pexp.defjvp def log1pexp_jvp(primals, tangents): x, = primals x_dot, = tangents ans = log1pexp(x) ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot return ans, ans_dot
print(grad(log1pexp)(100.))
1.0
print(jit(log1pexp)(3.)) print(jit(grad(log1pexp))(3.)) print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
3.0485873 0.95257413 [0.5 0.7310586 0.8807971]
这里是一个defjvps
方便包装,来表达同样的事情:
@custom_jvp def log1pexp(x): return jnp.log(1. + jnp.exp(x)) log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t)
print(grad(log1pexp)(100.)) print(jit(log1pexp)(3.)) print(jit(grad(log1pexp))(3.)) print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
1.0 3.0485873 0.95257413 [0.5 0.7310586 0.8807971]
强制执行微分约定
一个相关的应用是强制执行微分约定,也许在边界处。
考虑函数 (f : \mathbb{R}+ \to \mathbb{R}+),其中 (f(x) = \frac{x}{1 + \sqrt{x}}),其中我们取 (\mathbb{R}_+ = [0, \infty))。我们可以像这样实现 (f) 的程序:
def f(x): return x / (1 + jnp.sqrt(x))
作为在(\mathbb{R})上的数学函数(完整的实数线),(f) 在零点是不可微的(因为从左侧定义导数的极限不存在)。相应地,自动微分产生一个nan
值:
print(grad(f)(0.))
nan
但是数学上,如果我们将 (f) 视为 (\mathbb{R}_+) 上的函数,则它在 0 处是可微的 [Rudin 的《数学分析原理》定义 5.1,或 Tao 的《分析 I》第 3 版定义 10.1.1 和例子 10.1.6]。或者,我们可能会说,作为一个惯例,我们希望考虑从右边的方向导数。因此,对于 Python 函数grad(f)
在0.0
处返回 1.0 是有意义的值。默认情况下,JAX 对微分的机制假设所有函数在(\mathbb{R})上定义,因此这里并不会产生1.0
。
我们可以使用自定义的 JVP 规则!特别地,我们可以定义 JVP 规则,关于导数函数 (x \mapsto \frac{\sqrt{x} + 2}{2(\sqrt{x} + 1)²}) 在 (\mathbb{R}_+) 上,
@custom_jvp def f(x): return x / (1 + jnp.sqrt(x)) @f.defjvp def f_jvp(primals, tangents): x, = primals x_dot, = tangents ans = f(x) ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot return ans, ans_dot
print(grad(f)(0.))
1.0
这里是方便包装版本:
@custom_jvp def f(x): return x / (1 + jnp.sqrt(x)) f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t)
print(grad(f)(0.))
1.0
梯度剪裁
虽然在某些情况下,我们想要表达一个数学微分计算,在其他情况下,我们甚至可能想要远离数学,来调整自动微分的计算。一个典型的例子是反向模式梯度剪裁。
对于梯度剪裁,我们可以使用jnp.clip
和一个jax.custom_vjp
仅逆模式规则:
from functools import partial from jax import custom_vjp @custom_vjp def clip_gradient(lo, hi, x): return x # identity function def clip_gradient_fwd(lo, hi, x): return x, (lo, hi) # save bounds as residuals def clip_gradient_bwd(res, g): lo, hi = res return (None, None, jnp.clip(g, lo, hi)) # use None to indicate zero cotangents for lo and hi clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
import matplotlib.pyplot as plt from jax import vmap t = jnp.linspace(0, 10, 1000) plt.plot(jnp.sin(t)) plt.plot(vmap(grad(jnp.sin))(t))
[<matplotlib.lines.Line2D at 0x7f43dfc210f0>]
def clip_sin(x): x = clip_gradient(-0.75, 0.75, x) return jnp.sin(x) plt.plot(clip_sin(t)) plt.plot(vmap(grad(clip_sin))(t))
[<matplotlib.lines.Line2D at 0x7f43ddb15fc0>]
Python 调试
另一个应用,是受开发工作流程而非数值驱动的动机,是在反向模式自动微分的后向传递中设置pdb
调试器跟踪。
在尝试追踪nan
运行时错误的来源,或者仅仔细检查传播的余切(梯度)值时,可以在反向传递中的特定点插入调试器非常有用。您可以使用jax.custom_vjp
来实现这一点。
我们将在下一节中推迟一个示例。
JAX 中文文档(八)(4)https://developer.aliyun.com/article/1559665