JAX 中文文档(八)(3)https://developer.aliyun.com/article/1559664
迭代实现的隐式函数微分
这个例子涉及到了数学中的深层问题!
另一个应用jax.custom_vjp
是对可通过jit
、vmap
等转换为 JAX 但由于某些原因不易 JAX 可区分的函数进行反向模式微分,也许是因为涉及lax.while_loop
。(无法生成 XLA HLO 程序有效计算 XLA HLO While 循环的反向模式导数,因为这将需要具有无界内存使用的程序,这在 XLA HLO 中是不可能表达的,至少不是通过通过 infeed/outfeed 的副作用交互。)
例如,考虑这个fixed_point
例程,它通过在while_loop
中迭代应用函数来计算一个不动点:
from jax.lax import while_loop def fixed_point(f, a, x_guess): def cond_fun(carry): x_prev, x = carry return jnp.abs(x_prev - x) > 1e-6 def body_fun(carry): _, x = carry return x, f(a, x) _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess))) return x_star
这是一种通过迭代应用函数(x_{t+1} = f(a, x_t))来数值解方程(x = f(a, x))以计算(x)的迭代过程,直到(x_{t+1})足够接近(x_t)。结果(x^)取决于参数(a),因此我们可以认为存在一个由方程(x = f(a, x))隐式定义的函数(a \mapsto x^(a))。
我们可以使用fixed_point
运行迭代过程以收敛,例如运行牛顿法来计算平方根,只执行加法、乘法和除法:
def newton_sqrt(a): update = lambda a, x: 0.5 * (x + a / x) return fixed_point(update, a, a)
print(newton_sqrt(2.))
1.4142135
我们也可以对函数进行vmap
或jit
处理:
print(jit(vmap(newton_sqrt))(jnp.array([1., 2., 3., 4.])))
[1\. 1.4142135 1.7320509 2\. ]
由于while_loop
,我们无法应用反向模式自动微分,但事实证明我们也不想这样做:我们可以利用数学结构做一些更节省内存(在这种情况下也更节省 FLOP)的事情!我们可以使用隐函数定理[Bertsekas 的《非线性规划,第二版》附录 A.25],它保证(在某些条件下)我们即将使用的数学对象的存在。本质上,我们在线性化解决方案处进行线性化,并迭代解这些线性方程以计算我们想要的导数。
再次考虑方程(x = f(a, x))和函数(x^)。我们想要评估向量-Jacobian 乘积,如(v^\mathsf{T} \mapsto v^\mathsf{T} \partial x^(a_0))。
至少在我们想要求微分的点(a_0)周围的开放邻域内,让我们假设方程(x^(a) = f(a, x^(a)))对所有(a)都成立。由于两边作为(a)的函数相等,它们的导数也必须相等,所以让我们分别对两边进行微分:
(\qquad \partial x^(a) = \partial_0 f(a, x^(a)) + \partial_1 f(a, x^(a)) \partial x^(a))。
设置(A = \partial_1 f(a_0, x^(a_0)))和(B = \partial_0 f(a_0, x^(a_0))),我们可以更简单地写出我们想要的数量为
(\qquad \partial x^(a_0) = B + A \partial x^(a_0)),
或者,通过重新排列,
(\qquad \partial x^*(a_0) = (I - A)^{-1} B)。
这意味着我们可以评估向量-Jacobian 乘积,如
(\qquad v^\mathsf{T} \partial x^*(a_0) = v^\mathsf{T} (I - A)^{-1} B = w^\mathsf{T} B),
其中(w^\mathsf{T} = v^\mathsf{T} (I - A){-1}),或者等效地(w\mathsf{T} = v^\mathsf{T} + w^\mathsf{T} A),或者等效地(w\mathsf{T})是映射(u\mathsf{T} \mapsto v^\mathsf{T} + u^\mathsf{T} A)的不动点。最后一个描述使我们可以根据对fixed_point
的调用来编写fixed_point
的 VJP!此外,在展开(A)和(B)之后,我们可以看到我们只需要在((a_0, x^*(a_0)))处评估(f)的 VJP。
这里是要点:
from jax import vjp @partial(custom_vjp, nondiff_argnums=(0,)) def fixed_point(f, a, x_guess): def cond_fun(carry): x_prev, x = carry return jnp.abs(x_prev - x) > 1e-6 def body_fun(carry): _, x = carry return x, f(a, x) _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess))) return x_star def fixed_point_fwd(f, a, x_init): x_star = fixed_point(f, a, x_init) return x_star, (a, x_star) def fixed_point_rev(f, res, x_star_bar): a, x_star = res _, vjp_a = vjp(lambda a: f(a, x_star), a) a_bar, = vjp_a(fixed_point(partial(rev_iter, f), (a, x_star, x_star_bar), x_star_bar)) return a_bar, jnp.zeros_like(x_star) def rev_iter(f, packed, u): a, x_star, x_star_bar = packed _, vjp_x = vjp(lambda x: f(a, x), x_star) return x_star_bar + vjp_x(u)[0] fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)
print(newton_sqrt(2.))
1.4142135
print(grad(newton_sqrt)(2.)) print(grad(grad(newton_sqrt))(2.))
0.35355338 -0.088388346
我们可以通过对 jnp.sqrt
进行微分来检查我们的答案,它使用了完全不同的实现:
print(grad(jnp.sqrt)(2.)) print(grad(grad(jnp.sqrt))(2.))
0.35355338 -0.08838835
这种方法的一个限制是参数f
不能涉及到任何参与微分的值。也就是说,你可能注意到我们在fixed_point
的参数列表中明确保留了参数a
。对于这种用例,考虑使用低级原语lax.custom_root
,它允许在闭合变量中进行带有自定义根查找函数的导数。
使用 jax.custom_jvp
和 jax.custom_vjp
API 的基本用法
使用 jax.custom_jvp
来定义前向模式(以及间接地,反向模式)规则
这里是使用 jax.custom_jvp
的典型基本示例,其中注释使用Haskell-like type signatures。
from jax import custom_jvp import jax.numpy as jnp # f :: a -> b @custom_jvp def f(x): return jnp.sin(x) # f_jvp :: (a, T a) -> (b, T b) def f_jvp(primals, tangents): x, = primals t, = tangents return f(x), jnp.cos(x) * t f.defjvp(f_jvp)
<function __main__.f_jvp(primals, tangents)>
from jax import jvp print(f(3.)) y, y_dot = jvp(f, (3.,), (1.,)) print(y) print(y_dot)
0.14112 0.14112 -0.9899925
简言之,我们从一个原始函数f
开始,它接受类型为a
的输入并产生类型为b
的输出。我们与之关联一个 JVP 规则函数f_jvp
,它接受一对输入,表示类型为a
的原始输入和类型为T a
的相应切线输入,并产生一对输出,表示类型为b
的原始输出和类型为T b
的切线输出。切线输出应该是切线输入的线性函数。
你还可以使用 f.defjvp
作为装饰器,就像这样
@custom_jvp def f(x): ... @f.defjvp def f_jvp(primals, tangents): ...
尽管我们只定义了一个 JVP 规则而没有 VJP 规则,但我们可以在f
上同时使用正向和反向模式的微分。JAX 会自动将切线值上的线性计算从我们的自定义 JVP 规则转置,高效地计算出 VJP,就好像我们手工编写了规则一样。
from jax import grad print(grad(f)(3.)) print(grad(grad(f))(3.))
-0.9899925 -0.14112
为了使自动转置工作,JVP 规则的输出切线必须是输入切线的线性函数。否则将引发转置错误。
多个参数的工作方式如下:
@custom_jvp def f(x, y): return x ** 2 * y @f.defjvp def f_jvp(primals, tangents): x, y = primals x_dot, y_dot = tangents primal_out = f(x, y) tangent_out = 2 * x * y * x_dot + x ** 2 * y_dot return primal_out, tangent_out
print(grad(f)(2., 3.))
12.0
defjvps
便捷包装器允许我们为每个参数单独定义一个 JVP,并分别计算结果后进行求和:
@custom_jvp def f(x): return jnp.sin(x) f.defjvps(lambda t, ans, x: jnp.cos(x) * t)
print(grad(f)(3.))
-0.9899925
下面是一个带有多个参数的defjvps
示例:
@custom_jvp def f(x, y): return x ** 2 * y f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot, lambda y_dot, primal_out, x, y: x ** 2 * y_dot)
print(grad(f)(2., 3.)) print(grad(f, 0)(2., 3.)) # same as above print(grad(f, 1)(2., 3.))
12.0 12.0 4.0
简而言之,使用defjvps
,您可以传递None
值来指示特定参数的 JVP 为零:
@custom_jvp def f(x, y): return x ** 2 * y f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot, None)
print(grad(f)(2., 3.)) print(grad(f, 0)(2., 3.)) # same as above print(grad(f, 1)(2., 3.))
12.0 12.0 0.0
使用关键字参数调用jax.custom_jvp
函数,或者编写具有默认参数的jax.custom_jvp
函数定义,只要能够根据通过标准库inspect.signature
机制检索到的函数签名映射到位置参数即可。
当您不执行微分时,函数f
的调用方式与未被jax.custom_jvp
修饰时完全一样:
@custom_jvp def f(x): print('called f!') # a harmless side-effect return jnp.sin(x) @f.defjvp def f_jvp(primals, tangents): print('called f_jvp!') # a harmless side-effect x, = primals t, = tangents return f(x), jnp.cos(x) * t
from jax import vmap, jit print(f(3.))
called f! 0.14112
print(vmap(f)(jnp.arange(3.))) print(jit(f)(3.))
called f! [0\. 0.84147096 0.9092974 ] called f! 0.14112
自定义的 JVP 规则在微分过程中被调用,无论是正向还是反向:
y, y_dot = jvp(f, (3.,), (1.,)) print(y_dot)
called f_jvp! called f! -0.9899925
print(grad(f)(3.))
called f_jvp! called f! -0.9899925
注意,f_jvp
调用f
来计算原始输出。在高阶微分的上下文中,每个微分变换的应用将只在规则调用原始f
来计算原始输出时使用自定义的 JVP 规则。(这代表一种基本的权衡,我们不能同时利用f
的评估中间值来制定规则并且使规则在所有高阶微分顺序中应用。)
grad(grad(f))(3.)
called f_jvp! called f_jvp! called f!
Array(-0.14112, dtype=float32, weak_type=True)
您可以使用 Python 控制流来使用jax.custom_jvp
:
@custom_jvp def f(x): if x > 0: return jnp.sin(x) else: return jnp.cos(x) @f.defjvp def f_jvp(primals, tangents): x, = primals x_dot, = tangents ans = f(x) if x > 0: return ans, 2 * x_dot else: return ans, 3 * x_dot
print(grad(f)(1.)) print(grad(f)(-1.))
2.0 3.0
JAX 中文文档(八)(5)https://developer.aliyun.com/article/1559666