JAX 中文文档(八)(5)

简介: JAX 中文文档(八)

JAX 中文文档(八)(4)https://developer.aliyun.com/article/1559665


使用jax.custom_vjp来定义自定义的仅反向模式规则

虽然jax.custom_jvp足以控制正向和通过 JAX 的自动转置控制反向模式微分行为,但在某些情况下,我们可能希望直接控制 VJP 规则,例如在上述后两个示例问题中。我们可以通过jax.custom_vjp来实现这一点。

from jax import custom_vjp
import jax.numpy as jnp
# f :: a -> b
@custom_vjp
def f(x):
  return jnp.sin(x)
# f_fwd :: a -> (b, c)
def f_fwd(x):
  return f(x), jnp.cos(x)
# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, y_bar):
  return (cos_x * y_bar,)
f.defvjp(f_fwd, f_bwd) 
from jax import grad
print(f(3.))
print(grad(f)(3.)) 
0.14112
-0.9899925 

换句话说,我们再次从接受类型为a的输入并产生类型为b的输出的原始函数f开始。我们将与之关联两个函数f_fwdf_bwd,它们描述了如何执行反向模式自动微分的正向和反向传递。

函数f_fwd描述了前向传播,不仅包括原始计算,还包括要保存以供后向传播使用的值。其输入签名与原始函数f完全相同,即它接受类型为a的原始输入。但作为输出,它产生一对值,其中第一个元素是原始输出b,第二个元素是类型为c的任何“残余”数据,用于后向传播时存储。(这第二个输出类似于PyTorch 的 save_for_backward 机制。)

函数f_bwd描述了反向传播。它接受两个输入,第一个是由f_fwd生成的类型为c的残差数据,第二个是对应于原始函数输出的类型为CT b的输出共切线。它生成一个类型为CT a的输出,表示原始函数输入对应的共切线。特别地,f_bwd的输出必须是长度等于原始函数参数个数的序列(例如元组)。

多个参数的工作方式如下:

from jax import custom_vjp
@custom_vjp
def f(x, y):
  return jnp.sin(x) * y
def f_fwd(x, y):
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)
def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)
f.defvjp(f_fwd, f_bwd) 
print(grad(f)(2., 3.)) 
-1.2484405 

调用带有关键字参数的jax.custom_vjp函数,或者编写带有默认参数的jax.custom_vjp函数定义,只要可以根据标准库inspect.signature机制清晰地映射到位置参数即可。

jax.custom_jvp类似,如果没有应用微分,则不会调用由f_fwdf_bwd组成的自定义 VJP 规则。如果对函数进行评估,或者使用jitvmap或其他非微分变换进行转换,则只调用f

@custom_vjp
def f(x):
  print("called f!")
  return jnp.sin(x)
def f_fwd(x):
  print("called f_fwd!")
  return f(x), jnp.cos(x)
def f_bwd(cos_x, y_bar):
  print("called f_bwd!")
  return (cos_x * y_bar,)
f.defvjp(f_fwd, f_bwd) 
print(f(3.)) 
called f!
0.14112 
print(grad(f)(3.)) 
called f_fwd!
called f!
called f_bwd!
-0.9899925 
from jax import vjp
y, f_vjp = vjp(f, 3.)
print(y) 
called f_fwd!
called f!
0.14112 
print(f_vjp(1.)) 
called f_bwd!
(Array(-0.9899925, dtype=float32, weak_type=True),) 

无法在 jax.custom_vjp 函数上使用前向模式自动微分,否则会引发错误:

from jax import jvp
try:
  jvp(f, (3.,), (1.,))
except TypeError as e:
  print('ERROR! {}'.format(e)) 
called f_fwd!
called f!
ERROR! can't apply forward-mode autodiff (jvp) to a custom_vjp function. 

如果希望同时使用前向和反向模式,请使用jax.custom_jvp

我们可以使用jax.custom_vjppdb一起在反向传播中插入调试器跟踪:

import pdb
@custom_vjp
def debug(x):
  return x  # acts like identity
def debug_fwd(x):
  return x, x
def debug_bwd(x, g):
  import pdb; pdb.set_trace()
  return g
debug.defvjp(debug_fwd, debug_bwd) 
def foo(x):
  y = x ** 2
  y = debug(y)  # insert pdb in corresponding backward pass step
  return jnp.sin(y) 
jax.grad(foo)(3.)
> <ipython-input-113-b19a2dc1abf7>(12)debug_bwd()
-> return g
(Pdb) p x
Array(9., dtype=float32)
(Pdb) p g
Array(-0.91113025, dtype=float32)
(Pdb) q 

更多特性和细节

使用list / tuple / dict容器(和其他 pytree)

你应该期望标准的 Python 容器如列表、元组、命名元组和字典可以正常工作,以及这些容器的嵌套版本。总体而言,任何pytrees都是允许的,只要它们的结构符合类型约束。

这里有一个使用jax.custom_jvp的构造示例:

from collections import namedtuple
Point = namedtuple("Point", ["x", "y"])
@custom_jvp
def f(pt):
  x, y = pt.x, pt.y
  return {'a': x ** 2,
          'b': (jnp.sin(x), jnp.cos(y))}
@f.defjvp
def f_jvp(primals, tangents):
  pt, = primals
  pt_dot, =  tangents
  ans = f(pt)
  ans_dot = {'a': 2 * pt.x * pt_dot.x,
             'b': (jnp.cos(pt.x) * pt_dot.x, -jnp.sin(pt.y) * pt_dot.y)}
  return ans, ans_dot
def fun(pt):
  dct = f(pt)
  return dct['a'] + dct['b'][0] 
pt = Point(1., 2.)
print(f(pt)) 
{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))} 
print(grad(fun)(pt)) 
Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(0., dtype=float32, weak_type=True)) 

还有一个类似的使用jax.custom_vjp的构造示例:

@custom_vjp
def f(pt):
  x, y = pt.x, pt.y
  return {'a': x ** 2,
          'b': (jnp.sin(x), jnp.cos(y))}
def f_fwd(pt):
  return f(pt), pt
def f_bwd(pt, g):
  a_bar, (b0_bar, b1_bar) = g['a'], g['b']
  x_bar = 2 * pt.x * a_bar + jnp.cos(pt.x) * b0_bar
  y_bar = -jnp.sin(pt.y) * b1_bar
  return (Point(x_bar, y_bar),)
f.defvjp(f_fwd, f_bwd)
def fun(pt):
  dct = f(pt)
  return dct['a'] + dct['b'][0] 
pt = Point(1., 2.)
print(f(pt)) 
{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))} 
print(grad(fun)(pt)) 
Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(-0., dtype=float32, weak_type=True)) 

处理非可微参数

一些用例,如最后的示例问题,需要将非可微参数(如函数值参数)传递给具有自定义微分规则的函数,并且这些参数也需要传递给规则本身。在fixed_point的情况下,函数参数f就是这样一个非可微参数。类似的情况在jax.experimental.odeint中也会出现。

jax.custom_jvpnondiff_argnums

使用可选的 nondiff_argnums 参数来指示类似这些的参数给 jax.custom_jvp。以下是一个带有 jax.custom_jvp 的例子:

from functools import partial
@partial(custom_jvp, nondiff_argnums=(0,))
def app(f, x):
  return f(x)
@app.defjvp
def app_jvp(f, primals, tangents):
  x, = primals
  x_dot, = tangents
  return f(x), 2. * x_dot 
print(app(lambda x: x ** 3, 3.)) 
27.0 
print(grad(app, 1)(lambda x: x ** 3, 3.)) 
2.0 

注意这里的陷阱:无论这些参数在参数列表的哪个位置出现,它们都放置在相应 JVP 规则签名的起始位置。这里有另一个例子:

@partial(custom_jvp, nondiff_argnums=(0, 2))
def app2(f, x, g):
  return f(g((x)))
@app2.defjvp
def app2_jvp(f, g, primals, tangents):
  x, = primals
  x_dot, = tangents
  return f(g(x)), 3. * x_dot 
print(app2(lambda x: x ** 3, 3., lambda y: 5 * y)) 
3375.0 
print(grad(app2, 1)(lambda x: x ** 3, 3., lambda y: 5 * y)) 
3.0 
nondiff_argnumsjax.custom_vjp

对于 jax.custom_vjp 也有类似的选项,类似地,非可微参数的约定是它们作为 _bwd 规则的第一个参数传递,无论它们出现在原始函数签名的哪个位置。 _fwd 规则的签名保持不变 - 它与原始函数的签名相同。以下是一个例子:

@partial(custom_vjp, nondiff_argnums=(0,))
def app(f, x):
  return f(x)
def app_fwd(f, x):
  return f(x), x
def app_bwd(f, x, g):
  return (5 * g,)
app.defvjp(app_fwd, app_bwd) 
print(app(lambda x: x ** 2, 4.)) 
16.0 
print(grad(app, 1)(lambda x: x ** 2, 4.)) 
5.0 

请参见上面的 fixed_point 以获取另一个用法示例。

对于具有整数 dtype 的数组值参数,不需要使用 nondiff_argnums **。相反,nondiff_argnums 应仅用于不对应 JAX 类型(实质上不对应数组类型)的参数值,如 Python 可调用对象或字符串。如果 JAX 检测到由 nondiff_argnums 指示的参数包含 JAX Tracer,则会引发错误。上面的 clip_gradient 函数是不使用 nondiff_argnums 处理整数 dtype 数组参数的良好示例。

相关文章
|
9天前
|
并行计算 API 异构计算
JAX 中文文档(六)(2)
JAX 中文文档(六)
13 1
|
9天前
|
并行计算 API C++
JAX 中文文档(九)(4)
JAX 中文文档(九)
12 1
|
9天前
|
机器学习/深度学习 测试技术 索引
JAX 中文文档(二)(4)
JAX 中文文档(二)
9 0
|
9天前
|
C++ 索引 Python
JAX 中文文档(九)(2)
JAX 中文文档(九)
7 0
|
9天前
|
Serverless C++ Python
JAX 中文文档(九)(5)
JAX 中文文档(九)
15 0
|
9天前
|
机器学习/深度学习 存储 并行计算
JAX 中文文档(七)(3)
JAX 中文文档(七)
10 0
|
9天前
|
机器学习/深度学习 索引 Python
JAX 中文文档(四)(1)
JAX 中文文档(四)
10 0
|
9天前
|
机器学习/深度学习 程序员 编译器
JAX 中文文档(三)(1)
JAX 中文文档(三)
8 0
|
9天前
|
并行计算 编译器
JAX 中文文档(六)(4)
JAX 中文文档(六)
9 0
|
9天前
|
数据可视化 TensorFlow 算法框架/工具
JAX 中文文档(三)(2)
JAX 中文文档(三)
10 0