JAX 中文文档(八)(4)https://developer.aliyun.com/article/1559665
使用jax.custom_vjp
来定义自定义的仅反向模式规则
虽然jax.custom_jvp
足以控制正向和通过 JAX 的自动转置控制反向模式微分行为,但在某些情况下,我们可能希望直接控制 VJP 规则,例如在上述后两个示例问题中。我们可以通过jax.custom_vjp
来实现这一点。
from jax import custom_vjp import jax.numpy as jnp # f :: a -> b @custom_vjp def f(x): return jnp.sin(x) # f_fwd :: a -> (b, c) def f_fwd(x): return f(x), jnp.cos(x) # f_bwd :: (c, CT b) -> CT a def f_bwd(cos_x, y_bar): return (cos_x * y_bar,) f.defvjp(f_fwd, f_bwd)
from jax import grad print(f(3.)) print(grad(f)(3.))
0.14112 -0.9899925
换句话说,我们再次从接受类型为a
的输入并产生类型为b
的输出的原始函数f
开始。我们将与之关联两个函数f_fwd
和f_bwd
,它们描述了如何执行反向模式自动微分的正向和反向传递。
函数f_fwd
描述了前向传播,不仅包括原始计算,还包括要保存以供后向传播使用的值。其输入签名与原始函数f
完全相同,即它接受类型为a
的原始输入。但作为输出,它产生一对值,其中第一个元素是原始输出b
,第二个元素是类型为c
的任何“残余”数据,用于后向传播时存储。(这第二个输出类似于PyTorch 的 save_for_backward 机制。)
函数f_bwd
描述了反向传播。它接受两个输入,第一个是由f_fwd
生成的类型为c
的残差数据,第二个是对应于原始函数输出的类型为CT b
的输出共切线。它生成一个类型为CT a
的输出,表示原始函数输入对应的共切线。特别地,f_bwd
的输出必须是长度等于原始函数参数个数的序列(例如元组)。
多个参数的工作方式如下:
from jax import custom_vjp @custom_vjp def f(x, y): return jnp.sin(x) * y def f_fwd(x, y): return f(x, y), (jnp.cos(x), jnp.sin(x), y) def f_bwd(res, g): cos_x, sin_x, y = res return (cos_x * g * y, sin_x * g) f.defvjp(f_fwd, f_bwd)
print(grad(f)(2., 3.))
-1.2484405
调用带有关键字参数的jax.custom_vjp
函数,或者编写带有默认参数的jax.custom_vjp
函数定义,只要可以根据标准库inspect.signature
机制清晰地映射到位置参数即可。
与jax.custom_jvp
类似,如果没有应用微分,则不会调用由f_fwd
和f_bwd
组成的自定义 VJP 规则。如果对函数进行评估,或者使用jit
、vmap
或其他非微分变换进行转换,则只调用f
。
@custom_vjp def f(x): print("called f!") return jnp.sin(x) def f_fwd(x): print("called f_fwd!") return f(x), jnp.cos(x) def f_bwd(cos_x, y_bar): print("called f_bwd!") return (cos_x * y_bar,) f.defvjp(f_fwd, f_bwd)
print(f(3.))
called f! 0.14112
print(grad(f)(3.))
called f_fwd! called f! called f_bwd! -0.9899925
from jax import vjp y, f_vjp = vjp(f, 3.) print(y)
called f_fwd! called f! 0.14112
print(f_vjp(1.))
called f_bwd! (Array(-0.9899925, dtype=float32, weak_type=True),)
无法在 jax.custom_vjp
函数上使用前向模式自动微分,否则会引发错误:
from jax import jvp try: jvp(f, (3.,), (1.,)) except TypeError as e: print('ERROR! {}'.format(e))
called f_fwd! called f! ERROR! can't apply forward-mode autodiff (jvp) to a custom_vjp function.
如果希望同时使用前向和反向模式,请使用jax.custom_jvp
。
我们可以使用jax.custom_vjp
与pdb
一起在反向传播中插入调试器跟踪:
import pdb @custom_vjp def debug(x): return x # acts like identity def debug_fwd(x): return x, x def debug_bwd(x, g): import pdb; pdb.set_trace() return g debug.defvjp(debug_fwd, debug_bwd)
def foo(x): y = x ** 2 y = debug(y) # insert pdb in corresponding backward pass step return jnp.sin(y)
jax.grad(foo)(3.) > <ipython-input-113-b19a2dc1abf7>(12)debug_bwd() -> return g (Pdb) p x Array(9., dtype=float32) (Pdb) p g Array(-0.91113025, dtype=float32) (Pdb) q
更多特性和细节
使用list
/ tuple
/ dict
容器(和其他 pytree)
你应该期望标准的 Python 容器如列表、元组、命名元组和字典可以正常工作,以及这些容器的嵌套版本。总体而言,任何pytrees都是允许的,只要它们的结构符合类型约束。
这里有一个使用jax.custom_jvp
的构造示例:
from collections import namedtuple Point = namedtuple("Point", ["x", "y"]) @custom_jvp def f(pt): x, y = pt.x, pt.y return {'a': x ** 2, 'b': (jnp.sin(x), jnp.cos(y))} @f.defjvp def f_jvp(primals, tangents): pt, = primals pt_dot, = tangents ans = f(pt) ans_dot = {'a': 2 * pt.x * pt_dot.x, 'b': (jnp.cos(pt.x) * pt_dot.x, -jnp.sin(pt.y) * pt_dot.y)} return ans, ans_dot def fun(pt): dct = f(pt) return dct['a'] + dct['b'][0]
pt = Point(1., 2.) print(f(pt))
{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))}
print(grad(fun)(pt))
Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(0., dtype=float32, weak_type=True))
还有一个类似的使用jax.custom_vjp
的构造示例:
@custom_vjp def f(pt): x, y = pt.x, pt.y return {'a': x ** 2, 'b': (jnp.sin(x), jnp.cos(y))} def f_fwd(pt): return f(pt), pt def f_bwd(pt, g): a_bar, (b0_bar, b1_bar) = g['a'], g['b'] x_bar = 2 * pt.x * a_bar + jnp.cos(pt.x) * b0_bar y_bar = -jnp.sin(pt.y) * b1_bar return (Point(x_bar, y_bar),) f.defvjp(f_fwd, f_bwd) def fun(pt): dct = f(pt) return dct['a'] + dct['b'][0]
pt = Point(1., 2.) print(f(pt))
{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))}
print(grad(fun)(pt))
Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(-0., dtype=float32, weak_type=True))
处理非可微参数
一些用例,如最后的示例问题,需要将非可微参数(如函数值参数)传递给具有自定义微分规则的函数,并且这些参数也需要传递给规则本身。在fixed_point
的情况下,函数参数f
就是这样一个非可微参数。类似的情况在jax.experimental.odeint
中也会出现。
jax.custom_jvp
与nondiff_argnums
使用可选的 nondiff_argnums
参数来指示类似这些的参数给 jax.custom_jvp
。以下是一个带有 jax.custom_jvp
的例子:
from functools import partial @partial(custom_jvp, nondiff_argnums=(0,)) def app(f, x): return f(x) @app.defjvp def app_jvp(f, primals, tangents): x, = primals x_dot, = tangents return f(x), 2. * x_dot
print(app(lambda x: x ** 3, 3.))
27.0
print(grad(app, 1)(lambda x: x ** 3, 3.))
2.0
注意这里的陷阱:无论这些参数在参数列表的哪个位置出现,它们都放置在相应 JVP 规则签名的起始位置。这里有另一个例子:
@partial(custom_jvp, nondiff_argnums=(0, 2)) def app2(f, x, g): return f(g((x))) @app2.defjvp def app2_jvp(f, g, primals, tangents): x, = primals x_dot, = tangents return f(g(x)), 3. * x_dot
print(app2(lambda x: x ** 3, 3., lambda y: 5 * y))
3375.0
print(grad(app2, 1)(lambda x: x ** 3, 3., lambda y: 5 * y))
3.0
nondiff_argnums
与 jax.custom_vjp
对于 jax.custom_vjp
也有类似的选项,类似地,非可微参数的约定是它们作为 _bwd
规则的第一个参数传递,无论它们出现在原始函数签名的哪个位置。 _fwd
规则的签名保持不变 - 它与原始函数的签名相同。以下是一个例子:
@partial(custom_vjp, nondiff_argnums=(0,)) def app(f, x): return f(x) def app_fwd(f, x): return f(x), x def app_bwd(f, x, g): return (5 * g,) app.defvjp(app_fwd, app_bwd)
print(app(lambda x: x ** 2, 4.))
16.0
print(grad(app, 1)(lambda x: x ** 2, 4.))
5.0
请参见上面的 fixed_point
以获取另一个用法示例。
对于具有整数 dtype 的数组值参数,不需要使用 nondiff_argnums
**。相反,nondiff_argnums
应仅用于不对应 JAX 类型(实质上不对应数组类型)的参数值,如 Python 可调用对象或字符串。如果 JAX 检测到由 nondiff_argnums
指示的参数包含 JAX Tracer,则会引发错误。上面的 clip_gradient
函数是不使用 nondiff_argnums
处理整数 dtype 数组参数的良好示例。