JAX 中文文档(八)(4)

简介: JAX 中文文档(八)

JAX 中文文档(八)(3)https://developer.aliyun.com/article/1559664


迭代实现的隐式函数微分

这个例子涉及到了数学中的深层问题!

另一个应用jax.custom_vjp是对可通过jitvmap等转换为 JAX 但由于某些原因不易 JAX 可区分的函数进行反向模式微分,也许是因为涉及lax.while_loop。(无法生成 XLA HLO 程序有效计算 XLA HLO While 循环的反向模式导数,因为这将需要具有无界内存使用的程序,这在 XLA HLO 中是不可能表达的,至少不是通过通过 infeed/outfeed 的副作用交互。)

例如,考虑这个fixed_point例程,它通过在while_loop中迭代应用函数来计算一个不动点:

from jax.lax import while_loop
def fixed_point(f, a, x_guess):
  def cond_fun(carry):
    x_prev, x = carry
    return jnp.abs(x_prev - x) > 1e-6
  def body_fun(carry):
    _, x = carry
    return x, f(a, x)
  _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
  return x_star 

这是一种通过迭代应用函数(x_{t+1} = f(a, x_t))来数值解方程(x = f(a, x))以计算(x)的迭代过程,直到(x_{t+1})足够接近(x_t)。结果(x^)取决于参数(a),因此我们可以认为存在一个由方程(x = f(a, x))隐式定义的函数(a \mapsto x^(a))。

我们可以使用fixed_point运行迭代过程以收敛,例如运行牛顿法来计算平方根,只执行加法、乘法和除法:

def newton_sqrt(a):
  update = lambda a, x: 0.5 * (x + a / x)
  return fixed_point(update, a, a) 
print(newton_sqrt(2.)) 
1.4142135 

我们也可以对函数进行vmapjit处理:

print(jit(vmap(newton_sqrt))(jnp.array([1., 2., 3., 4.]))) 
[1\.        1.4142135 1.7320509 2\.       ] 

由于while_loop,我们无法应用反向模式自动微分,但事实证明我们也不想这样做:我们可以利用数学结构做一些更节省内存(在这种情况下也更节省  FLOP)的事情!我们可以使用隐函数定理[Bertsekas 的《非线性规划,第二版》附录  A.25],它保证(在某些条件下)我们即将使用的数学对象的存在。本质上,我们在线性化解决方案处进行线性化,并迭代解这些线性方程以计算我们想要的导数。

再次考虑方程(x = f(a, x))和函数(x^)。我们想要评估向量-Jacobian 乘积,如(v^\mathsf{T} \mapsto v^\mathsf{T} \partial x^(a_0))。

至少在我们想要求微分的点(a_0)周围的开放邻域内,让我们假设方程(x^(a) = f(a, x^(a)))对所有(a)都成立。由于两边作为(a)的函数相等,它们的导数也必须相等,所以让我们分别对两边进行微分:

(\qquad \partial x^(a) = \partial_0 f(a, x^(a)) + \partial_1 f(a, x^(a)) \partial x^(a))。

设置(A = \partial_1 f(a_0, x^(a_0)))和(B = \partial_0 f(a_0, x^(a_0))),我们可以更简单地写出我们想要的数量为

(\qquad \partial x^(a_0) = B + A \partial x^(a_0)),

或者,通过重新排列,

(\qquad \partial x^*(a_0) = (I - A)^{-1} B)。

这意味着我们可以评估向量-Jacobian 乘积,如

(\qquad v^\mathsf{T} \partial x^*(a_0) = v^\mathsf{T} (I - A)^{-1} B = w^\mathsf{T} B),

其中(w^\mathsf{T} = v^\mathsf{T} (I - A){-1}),或者等效地(w\mathsf{T} = v^\mathsf{T} + w^\mathsf{T} A),或者等效地(w\mathsf{T})是映射(u\mathsf{T} \mapsto v^\mathsf{T} + u^\mathsf{T} A)的不动点。最后一个描述使我们可以根据对fixed_point的调用来编写fixed_point的 VJP!此外,在展开(A)和(B)之后,我们可以看到我们只需要在((a_0, x^*(a_0)))处评估(f)的 VJP。

这里是要点:

from jax import vjp
@partial(custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
  def cond_fun(carry):
    x_prev, x = carry
    return jnp.abs(x_prev - x) > 1e-6
  def body_fun(carry):
    _, x = carry
    return x, f(a, x)
  _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
  return x_star
def fixed_point_fwd(f, a, x_init):
  x_star = fixed_point(f, a, x_init)
  return x_star, (a, x_star)
def fixed_point_rev(f, res, x_star_bar):
  a, x_star = res
  _, vjp_a = vjp(lambda a: f(a, x_star), a)
  a_bar, = vjp_a(fixed_point(partial(rev_iter, f),
                             (a, x_star, x_star_bar),
                             x_star_bar))
  return a_bar, jnp.zeros_like(x_star)
def rev_iter(f, packed, u):
  a, x_star, x_star_bar = packed
  _, vjp_x = vjp(lambda x: f(a, x), x_star)
  return x_star_bar + vjp_x(u)[0]
fixed_point.defvjp(fixed_point_fwd, fixed_point_rev) 
print(newton_sqrt(2.)) 
1.4142135 
print(grad(newton_sqrt)(2.))
print(grad(grad(newton_sqrt))(2.)) 
0.35355338
-0.088388346 

我们可以通过对 jnp.sqrt 进行微分来检查我们的答案,它使用了完全不同的实现:

print(grad(jnp.sqrt)(2.))
print(grad(grad(jnp.sqrt))(2.)) 
0.35355338
-0.08838835 

这种方法的一个限制是参数f不能涉及到任何参与微分的值。也就是说,你可能注意到我们在fixed_point的参数列表中明确保留了参数a。对于这种用例,考虑使用低级原语lax.custom_root,它允许在闭合变量中进行带有自定义根查找函数的导数。

使用 jax.custom_jvpjax.custom_vjp API 的基本用法

使用 jax.custom_jvp 来定义前向模式(以及间接地,反向模式)规则

这里是使用 jax.custom_jvp 的典型基本示例,其中注释使用Haskell-like type signatures

from jax import custom_jvp
import jax.numpy as jnp
# f :: a -> b
@custom_jvp
def f(x):
  return jnp.sin(x)
# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), jnp.cos(x) * t
f.defjvp(f_jvp) 
<function __main__.f_jvp(primals, tangents)> 
from jax import jvp
print(f(3.))
y, y_dot = jvp(f, (3.,), (1.,))
print(y)
print(y_dot) 
0.14112
0.14112
-0.9899925 

简言之,我们从一个原始函数f开始,它接受类型为a的输入并产生类型为b的输出。我们与之关联一个 JVP 规则函数f_jvp,它接受一对输入,表示类型为a的原始输入和类型为T a的相应切线输入,并产生一对输出,表示类型为b的原始输出和类型为T b的切线输出。切线输出应该是切线输入的线性函数。

你还可以使用 f.defjvp 作为装饰器,就像这样

@custom_jvp
def f(x):
  ...
@f.defjvp
def f_jvp(primals, tangents):
  ... 

尽管我们只定义了一个 JVP 规则而没有 VJP 规则,但我们可以在f上同时使用正向和反向模式的微分。JAX 会自动将切线值上的线性计算从我们的自定义 JVP 规则转置,高效地计算出 VJP,就好像我们手工编写了规则一样。

from jax import grad
print(grad(f)(3.))
print(grad(grad(f))(3.)) 
-0.9899925
-0.14112 

为了使自动转置工作,JVP 规则的输出切线必须是输入切线的线性函数。否则将引发转置错误。

多个参数的工作方式如下:

@custom_jvp
def f(x, y):
  return x ** 2 * y
@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = 2 * x * y * x_dot + x ** 2 * y_dot
  return primal_out, tangent_out 
print(grad(f)(2., 3.)) 
12.0 

defjvps便捷包装器允许我们为每个参数单独定义一个 JVP,并分别计算结果后进行求和:

@custom_jvp
def f(x):
  return jnp.sin(x)
f.defjvps(lambda t, ans, x: jnp.cos(x) * t) 
print(grad(f)(3.)) 
-0.9899925 

下面是一个带有多个参数的defjvps示例:

@custom_jvp
def f(x, y):
  return x ** 2 * y
f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
          lambda y_dot, primal_out, x, y: x ** 2 * y_dot) 
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.))  # same as above
print(grad(f, 1)(2., 3.)) 
12.0
12.0
4.0 

简而言之,使用defjvps,您可以传递None值来指示特定参数的 JVP 为零:

@custom_jvp
def f(x, y):
  return x ** 2 * y
f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
          None) 
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.))  # same as above
print(grad(f, 1)(2., 3.)) 
12.0
12.0
0.0 

使用关键字参数调用jax.custom_jvp函数,或者编写具有默认参数的jax.custom_jvp函数定义,只要能够根据通过标准库inspect.signature机制检索到的函数签名映射到位置参数即可。

当您不执行微分时,函数f的调用方式与未被jax.custom_jvp修饰时完全一样:

@custom_jvp
def f(x):
  print('called f!')  # a harmless side-effect
  return jnp.sin(x)
@f.defjvp
def f_jvp(primals, tangents):
  print('called f_jvp!')  # a harmless side-effect
  x, = primals
  t, = tangents
  return f(x), jnp.cos(x) * t 
from jax import vmap, jit
print(f(3.)) 
called f!
0.14112 
print(vmap(f)(jnp.arange(3.)))
print(jit(f)(3.)) 
called f!
[0\.         0.84147096 0.9092974 ]
called f!
0.14112 

自定义的 JVP 规则在微分过程中被调用,无论是正向还是反向:

y, y_dot = jvp(f, (3.,), (1.,))
print(y_dot) 
called f_jvp!
called f!
-0.9899925 
print(grad(f)(3.)) 
called f_jvp!
called f!
-0.9899925 

注意,f_jvp调用f来计算原始输出。在高阶微分的上下文中,每个微分变换的应用将只在规则调用原始f来计算原始输出时使用自定义的 JVP 规则。(这代表一种基本的权衡,我们不能同时利用f的评估中间值来制定规则并且使规则在所有高阶微分顺序中应用。)

grad(grad(f))(3.)
called f_jvp!
called f_jvp!
called f! 
Array(-0.14112, dtype=float32, weak_type=True) 

您可以使用 Python 控制流来使用jax.custom_jvp

@custom_jvp
def f(x):
  if x > 0:
    return jnp.sin(x)
  else:
    return jnp.cos(x)
@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = f(x)
  if x > 0:
    return ans, 2 * x_dot
  else:
    return ans, 3 * x_dot 
print(grad(f)(1.))
print(grad(f)(-1.)) 
2.0
3.0 


JAX 中文文档(八)(5)https://developer.aliyun.com/article/1559666

相关文章
|
4月前
|
机器学习/深度学习 存储 移动开发
JAX 中文文档(八)(1)
JAX 中文文档(八)
29 1
|
4月前
|
机器学习/深度学习 PyTorch API
JAX 中文文档(六)(1)
JAX 中文文档(六)
36 0
JAX 中文文档(六)(1)
|
4月前
|
编译器 API C++
JAX 中文文档(三)(3)
JAX 中文文档(三)
29 0
|
4月前
|
机器学习/深度学习 算法 异构计算
JAX 中文文档(七)(2)
JAX 中文文档(七)
29 0
|
4月前
|
存储 PyTorch 测试技术
JAX 中文文档(八)(5)
JAX 中文文档(八)
37 0
|
4月前
|
Serverless C++ Python
JAX 中文文档(九)(5)
JAX 中文文档(九)
35 0
|
4月前
|
C++ 索引 Python
JAX 中文文档(九)(2)
JAX 中文文档(九)
24 0
|
4月前
|
存储 Python
JAX 中文文档(十)(3)
JAX 中文文档(十)
29 0
|
4月前
|
Python
JAX 中文文档(十)(5)
JAX 中文文档(十)
27 0
|
4月前
|
存储 缓存 测试技术
JAX 中文文档(三)(5)
JAX 中文文档(三)
51 0