使用jax.checkpoint
控制自动微分的保存数值(又名jax.remat
)
原文:
jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html
import jax import jax.numpy as jnp
简而言之
使用jax.checkpoint
装饰器(别名为jax.remat
),结合jax.grad
来控制前向传播时保存哪些中间值,以及在反向传播时重新计算哪些中间值,从而在内存和 FLOP 之间进行权衡。
不要错过关于jax.checkpoint
如何与jax.jit
交互的实用说明。
如果不使用jax.checkpoint
,jax.grad(f)(x)
的前向传播将保存雅可比系数和其他中间值以供后向传播使用。我们称这些保存的值为残差:
def g(W, x): y = jnp.dot(W, x) return jnp.sin(y) def f(W1, W2, W3, x): x = g(W1, x) x = g(W2, x) x = g(W3, x) return x W1 = jnp.ones((5, 4)) W2 = jnp.ones((6, 5)) W3 = jnp.ones((7, 6)) x = jnp.ones(4) # Inspect the 'residual' values to be saved on the forward pass # if we were to evaluate `jax.grad(f)(W1, W2, W3, x)` from jax.ad_checkpoint import print_saved_residuals jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
f32[5,4] from the argument 'W1' f32[6,5] from the argument 'W2' f32[7,6] from the argument 'W3' f32[4] from the argument 'x' f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g) f32[5] output of cos from <ipython-input-4-f510dde58e22>:3 (g) f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g) f32[6] output of cos from <ipython-input-4-f510dde58e22>:3 (g) f32[7] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
通过对子函数应用jax.checkpoint
,无论是作为装饰器还是在特定的应用站点,我们都强制 JAX 不保存该子函数的任何残差。相反,只有jax.checkpoint
装饰的函数的输入可能会被保存,并且在反向传播时从这些输入重新计算任何消耗的残差:
def f2(W1, W2, W3, x): x = jax.checkpoint(g)(W1, x) x = jax.checkpoint(g)(W2, x) x = jax.checkpoint(g)(W3, x) return x jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
f32[5,4] from the argument 'W1' f32[6,5] from the argument 'W2' f32[7,6] from the argument 'W3' f32[4] from the argument 'x' f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g) f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
这里保存了两个sin
应用的值,因为它们是jax.checkpoint
装饰的g
函数后续应用的参数,并且jax.checkpoint
装饰的函数的输入可能会被保存。但没有保存任何cos
应用的值。
要控制哪些值可保存,而无需编辑要区分的函数的定义,您可以使用重新材料化策略。以下是一个例子,仅保存没有批次维度的dot
操作的结果(因为它们通常是 FLOP 限制的,因此值得保存而不是重新计算):
f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable) jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
f32[5,4] from the argument 'W1' f32[6,5] from the argument 'W2' f32[7,6] from the argument 'W3' f32[4] from the argument 'x' f32[5] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g) f32[6] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g) f32[7] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
您还可以使用策略来引用使用jax.ad_checkpoint.checkpoint_name
命名的中间值:
from jax.ad_checkpoint import checkpoint_name def f4(W1, W2, W3, x): x = checkpoint_name(g(W1, x), name='a') x = checkpoint_name(g(W2, x), name='b') x = checkpoint_name(g(W3, x), name='c') return x f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a')) jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
f32[5,4] from the argument 'W1' f32[6,5] from the argument 'W2' f32[7,6] from the argument 'W3' f32[4] from the argument 'x' f32[5] named 'a' from <ipython-input-7-fc0ed1c14b8d>:4 (f4)
在玩弄这些玩具示例时,我们可以使用在此笔记本中定义的print_fwd_bwd
实用程序更详细地了解正在进行的操作:
from jax.tree_util import tree_flatten, tree_unflatten from rich.console import Console from rich.table import Table import rich.text def print_fwd_bwd(f, *args, **kwargs) -> None: args, in_tree = tree_flatten((args, kwargs)) def f_(*args): args, kwargs = tree_unflatten(in_tree, args) return f(*args, **kwargs) fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr y, f_vjp = jax.vjp(f_, *args) res, in_tree = tree_flatten(f_vjp) def g_(*args): *res, y = args f_vjp = tree_unflatten(in_tree, res) return f_vjp(y) bwd = jax.make_jaxpr(g_)(*res, y).jaxpr table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None) table.add_row("[bold green]forward computation:", "[bold green]backward computation:") table.add_row(rich.text.Text.from_ansi(str(fwd)), rich.text.Text.from_ansi(str(bwd))) console = Console(width=240, force_jupyter=True) console.print(table) def _renderable_repr(self): return self.html rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
# no use of jax.checkpoint: print_fwd_bwd(f, W1, W2, W3, x)
forward computation: backward computation: { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let { lambda ; a:f32[7] b:f32[6] c:f32[7,6] d:f32[6] e:f32[5] f:f32[6,5] g:f32[5] h:f32[4] e:f32[5] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d i:f32[5,4] j:f32[7]. let f:f32[5] = sin e k:f32[7] = mul j a g:f32[5] = cos e l:f32[6] = dot_general[dimension_numbers=(([0], [0]), ([], []))] k c h:f32[6] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f m:f32[7,6] = dot_general[dimension_numbers=(([], []), ([], []))] k b i:f32[6] = sin h n:f32[6] = mul l d j:f32[6] = cos h o:f32[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] n f k:f32[7] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c i p:f32[6,5] = dot_general[dimension_numbers=(([], []), ([], []))] n e l:f32[7] = sin k q:f32[5] = mul o g m:f32[7] = cos k r:f32[4] = dot_general[dimension_numbers=(([0], [0]), ([], []))] q i in (l, m, i, c, j, f, b, g, d, a) } s:f32[5,4] = dot_general[dimension_numbers=(([], []), ([], []))] q h in (s, p, m, r) }
# using jax.checkpoint with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable: print_fwd_bwd(f3, W1, W2, W3, x)
forward computation: backward computation: { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let { lambda ; a:f32[5] b:f32[6] c:f32[7] d:f32[5,4] e:f32[6,5] f:f32[7,6] g:f32[4] h:f32[7]. let e:f32[5] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d i:f32[5,4] j:f32[6,5] k:f32[7,6] l:f32[4] = remat2[ f:f32[5] = sin e differentiated=True g:f32[6] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f jaxpr={ lambda ; m:f32[5] n:f32[6] o:f32[7] p:f32[5,4] q:f32[6,5] r:f32[7,6] h:f32[6] = sin g s:f32[4] t:f32[7]. let i:f32[7] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c h u:f32[5] = sin m j:f32[7] = sin i v:f32[5] = cos m in (j, e, g, i, a, b, c, d) } w:f32[6] = sin n x:f32[6] = cos n y:f32[7] = cos o z:f32[7] = mul t y ba:f32[6] = dot_general[dimension_numbers=(([0], [0]), ([], []))] z r bb:f32[6] = mul ba x bc:f32[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] bb q bd:f32[5] = mul bc v be:f32[4] = dot_general[dimension_numbers=(([0], [0]), ([], []))] bd p bf:f32[5,4] = dot_general[dimension_numbers=(([], []), ([], []))] bd s bg:f32[6,5] = dot_general[dimension_numbers=(([], []), ([], []))] bb u bh:f32[7,6] = dot_general[dimension_numbers=(([], []), ([], []))] z w in (bf, bg, bh, be) } policy=<function dot_with_no_batch_dims at 0x7f5e469b1700> prevent_cse=True ] a b c d e f g h in (i, j, k, l) }
让我们一步一步地思考
您可能希望首先(重新)阅读自动微分手册第一部分。
jax.checkpoint
的基础知识
在jax.linearize
和jax.vjp
中,如何以及何时计算某些值有灵活性。不同的选择可以在内存使用和 FLOP 之间进行权衡。JAX 通过jax.checkpoint
提供了对这些选择的控制。
其中之一是在前向传播时执行雅可比系数计算,即在输入可用时立即进行,或者在反向传播时,在需要系数之前进行。考虑sin_vjp
的例子:
def sin_vjp(x): y = jnp.sin(x) cos_x = jnp.cos(x) return y, lambda y_bar: cos_x * y_bar
在反向传播时,另一种有效的实现方式是计算jnp.cos(x)
的值,而不是在前向传播时:
def sin_vjp2(x): y = jnp.sin(x) return y, lambda y_bar: jnp.cos(x) * y_bar
对于这个特定的函数,两个版本使用的内存量是相同的,尽管我们减少了原始计算的 FLOP 并增加了余切计算的 FLOP。
当涉及函数组合时,我们还有另一种选择。回顾我们的两个函数组合的 VJP 规则:
def f(x): y = g(x) z = h(y) return z def f_vjp(x): y, g_vjp = jax.vjp(g, x) z, h_vjp = jax.vjp(h, y) def f_bwd(z_bar): y_bar, = h_vjp(z_bar) x_bar, = g_vjp(y_bar) return x_bar return z, f_bwd
另一种选择是:
def f_vjp_checkpoint(x): y = g(x) z, h_vjp = jax.vjp(h, y) def f_bwd2(z_bar): y_bar, = h_vjp(z_bar) _, g_vjp = jax.vjp(g, x) x_bar, = g_vjp(y_bar) return x_bar return z, f_bwd2
换句话说,这种替代实现不会在前向传播中计算g_vjp
或其闭包中的残差值。而是只在后向传播f_bwd2
中计算它们。这意味着f_vjp_checkpoint
需要更少的内存:如果g
和h
每个都需要类似量级的内存来存储其残差,远大于x
,那么由f_vjp_checkpoint(x)
生成的函数所需的内存量仅为f_vjp(x)
的一半!
我们所付出的代价是冗余工作:在f_bwd2
中,我们必须重新评估g(x)
作为jax.vjp(g, x)
的一部分,只是为了丢弃它的值(在下划线变量的行中_, g_vjp = jax.vjp(g, x)
)。
我们可以在自动微分中实现这种 VJP 行为,而不必直接编写 VJP 函数,而是通过在原始函数f
的另一种定义中使用jax.checkpoint
来实现:
def f_checkpoint(x): y = jax.checkpoint(g)(x) z = h(y) return z
换句话说,我们将jax.checkpoint
应用于f
的第一阶段g
,而不是f
本身。这样,当我们评估jax.grad(f_checkpoint)(x)
时,我们会得到如下计算:
- 运行
g
的前向传播,丢弃残差值; - 运行
h
的前向传播,保存残差; - 运行
h
的后向传播,使用步骤 2 中的残差; - 重新运行
g
的前向传播,保存残差; - 运行
g
的后向传播,使用步骤 4 中的残差。
换句话说,通过评估jax.grad(f_checkpoint)(x)
,我们会得到与如下计算相同的结果:
def f_checkpoint_grad(x): y = g(x) # step 1 _, h_vjp = jax.vjp(h)(y) # step 2 y_bar, = h_vjp(1.0) # step 3 _, g_vjp = jax.vjp(g, x) # step 4 x_bar, = g_vjp(y_bar) # step 5 return x_bar
通常情况下,jax.checkpoint(foo)
是一个新函数,其输入输出行为与foo
相同,但在自动微分下行为不同,特别是在jax.linearize
和jax.vjp
(以及它们的包装器,如jax.grad
)中,但不包括jax.jvp
。在求导时,只有经过jax.checkpoint
的函数的输入会在前向传播时存储;在后向传播时,会重新计算残差(即来自foo
及其雅可比系数值的中间值,这些值在后向传播时需要重新计算)。
注意,如果f = lambda x: h(g(x))
是我们想要求导的函数,即如果我们想应用jax.grad(f)
,那么对f
本身应用jax.checkpoint
不会节省任何内存。这是因为评估jax.grad(jax.checkpoint(f))(x)
会导致如下计算:
- 运行前向传播,丢弃所有残差;
- 立即重新运行前向传播,保存残差;
- 运行后向传播,使用步骤 2 中的残差。
换句话说,代码中我们会有类似这样的东西:
def f_grad_bad(x): _ = f(x) # step 1 _, f_vjp = jax.vjp(f, x) # step 2 x_bar, = f_vjp(1.0) # step 3 return x_bar
如果对h
的第二阶段应用jax.checkpoint
,我们也不会获得任何内存节省。这是因为评估jax.grad(lambda x: jax.checkpoint(h)(g(x)))
会导致如下计算:
- 运行
g
的前向传播,保存残差; - 运行
h
的前向传播,丢弃残差; - 立即重新运行
h
的前向传播,保存残差; - 运行
h
的后向传播,使用步骤 3 中的残差; - 运行
g
的后向传播,消耗步骤 1 中的剩余项。
这样,在代码中,我们会有类似以下的内容:
def f_grad_bad2(x): y, g_vjp = jax.vjp(g, x) # step 1 z = h(y) # step 2 _, h_vjp = jax.vjp(h, y) # step 3 y_bar, = h_vjp(1.0) # step 3 x_bar, = g_vjp(y_bar) # step 5 return x_bar
稍微更一般地说,如果我们有一个函数链组合,如f = lambda x: f3(f2(f1(x)))
,并且我们有兴趣评估jax.grad(f)
,我们可以说:
- 我们不应将
jax.checkpoint
应用于整个函数f
,因为这不会节省任何内存(并且会执行浪费的重新计算); - 我们不应将
jax.checkpoint
应用于最后一个子函数f3
,因为这不会节省任何内存(并且会执行浪费的重新计算); - 我们可以将
jax.checkpoint
应用于f1
、f2
或它们的组合lambda x: f2(f1(x))
,因为这些任意一个都可能节省内存,并且会表达不同的内存/重新计算折衷。
什么可以保存的自定义策略
到目前为止所展示的,使用jax.checkpoint
会从一个极端切换到另一个:
- 没有
jax.checkpoint
,JAX 的自动微分倾向于在前向传播中计算尽可能多的内容,并为后向传播存储它; - 使用
jax.checkpoint
装饰器,我们在前向传播中尽量少计算,并根据需要在后向传播中重新计算值。
要在这两个极端之间操作,保存某些东西而不保存其他东西,我们可以在子函数上谨慎地放置jax.checkpoint
装饰器。但这需要编辑要求微分的函数,例如模型代码,这可能不方便。也很难对变体进行实验。
因此,一个替代方法是使用jax.checkpoint
的policy
参数。策略是一个可调用对象(即一个函数),它以一种类型级别的原始应用规范作为输入,并返回一个布尔值,指示是否允许将相应的输出值保存为剩余项(或者必须在(共)切向计算中根据需要重新计算)。为了编写健壮的代码,应从jax.checkpoint_policies
的属性中选择策略,例如jax.checkpoint_policies.dots_with_no_batch_dims_saveable
,因为编写自定义策略可调用对象的 API 被认为是内部的。
例如,考虑要微分的这个函数:
def loss(params, x, y): return jnp.sum((predict(params, x) - y)**2) def predict(params, x): *Ws, Wlast = params for W in Ws: x = layer(W, x) x = jnp.dot(Wlast, x) return x def layer(W, x): return jnp.sin(jnp.dot(W, x))
W1 = W2 = W3 = jnp.ones((4, 4)) params = [W1, W2, W3] x = jnp.ones(4) y = jnp.ones(4)
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument 'params' f32[4,4] from the argument 'params' f32[4,4] from the argument 'params' f32[4] from the argument 'x' f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer) f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer) f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer) f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer) f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)
而不是在前向传播中保存这么多值,也许我们只想保存没有批处理维度的矩阵乘法结果(因为它们可能是 FLOP 而不是内存绑定)。我们可以使用策略jax.checkpoint_policies.dots_with_no_batch_dims_saveable
来实现这一点:
loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable) print_saved_residuals(loss_checkpoint, params, x, y)
f32[4,4] from the argument 'params' f32[4,4] from the argument 'params' f32[4,4] from the argument 'params' f32[4] from the argument 'x' f32[4] from the argument 'y' f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer) f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer) f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:8 (predict)
还要注意,通过提供一个策略,我们无需编辑定义loss
、predict
或layer
的代码。如果我们希望在调用代码(例如训练脚本)中进行策略实验而不更改库代码(例如神经网络库),这特别方便。
一些策略可以引用名为jax.ad_checkpoint.checkpoint_name
的值:
from jax.ad_checkpoint import checkpoint_name def predict(params, x): *Ws, Wlast = params for i, W in enumerate(Ws): x = layer(W, x) x = checkpoint_name(x, name=f'layer{i}_output') x = jnp.dot(Wlast, x) return x
单独看,checkpoint_name
只是一个身份函数。但因为某些策略函数知道如何查找它们,我们可以使用这些名称来控制 checkpoint_name
输出的某些值是否被视为可保存的:
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument 'params' f32[4,4] from the argument 'params' f32[4,4] from the argument 'params' f32[4] from the argument 'x' f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer) f32[4] named 'layer0_output' from <ipython-input-22-e48aedf368ad>:7 (predict) f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer) f32[4] named 'layer1_output' from <ipython-input-22-e48aedf368ad>:7 (predict) f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output')) print_saved_residuals(loss_checkpoint2, params, x, y)
f32[4,4] from the argument 'params' f32[4,4] from the argument 'params' f32[4,4] from the argument 'params' f32[4] from the argument 'x' f32[4] from the argument 'y'
另一个涉及名称的策略是 jax.checkpoint_policies.save_only_these_names
。
某些策略包括:
everything_saveable
(默认策略,就像根本没有使用jax.checkpoint
一样)nothing_saveable
(即重新生成所有内容,就像根本没有使用自定义策略一样)dots_saveable
或其别名checkpoint_dots
dots_with_no_batch_dims_saveable
或其别名checkpoint_dots_with_no_batch_dims
save_anything_but_these_names
(保存任何值,但不包括具有给定名称的checkpoint_name
输出)save_any_names_but_these
(仅保存命名值,即checkpoint_name
的任何输出,但不包括给定名称)save_only_these_names
(仅保存命名值,并且仅限于给定的名称)
策略仅指示可保存的内容;只有在反向传播实际需要时才会保存值。
高级:递归的 jax.checkpoint
通过适当地应用 jax.checkpoint
,可以表达许多内存使用和(重新)计算之间的权衡。一个令人惊讶的例子是 递归 检查点处理,在这种情况下,我们将 jax.checkpoint
应用于一个函数,该函数本身调用以 jax.checkpoint
装饰的函数,以便从 (D) 函数链的组合中内存使用按 (\mathcal{O}(\log_2 D)) 而非 (\mathcal{O}(D)) 缩放。
作为一个玩具例子,考虑多个 jnp.sin
函数的链式组合:
def chain_compose(funs): def f(x): for fun in funs: x = fun(x) return x return f f = chain_compose([jnp.sin] * 8) print_saved_residuals(f, 3.)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
通常来说,存储的残差数量与链的长度成线性比例:
f = chain_compose([jnp.sin] * 16) print_saved_residuals(f, 3.)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
但我们可以递归地应用 jax.checkpoint
来改善缩放效果:
def recursive_checkpoint(funs): if len(funs) == 1: return funs[0] elif len(funs) == 2: f1, f2 = funs return lambda x: f1(f2(x)) else: f1 = recursive_checkpoint(funs[:len(funs)//2]) f2 = recursive_checkpoint(funs[len(funs)//2:]) return lambda x: f1(jax.checkpoint(f2)(x))
f = recursive_checkpoint([jnp.sin] * 8) print_saved_residuals(f, 3.)
f32[] from the argument 'x' f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>) f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>) f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f = recursive_checkpoint([jnp.sin] * 16) print_saved_residuals(f, 3.)
f32[] from the argument 'x' f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>) f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>) f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>) f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
这里的成本,与通常一样,是重新计算:特别是,我们最终要执行 (\mathcal{O}(\log_2 D)) 倍的 FLOPs:
f = chain_compose([jnp.sin] * 8) print_fwd_bwd(f, 3.)
forward computation: backward computation: { lambda ; a:f32[]. let { lambda ; a:f32[] b:f32[] c:f32[] d:f32[] e:f32[] f:f32[] g:f32[] h:f32[] i:f32[]. let b:f32[] = sin a j:f32[] = mul i a c:f32[] = cos a k:f32[] = mul j b d:f32[] = sin b l:f32[] = mul k c e:f32[] = cos b m:f32[] = mul l d f:f32[] = sin d n:f32[] = mul m e g:f32[] = cos d o:f32[] = mul n f h:f32[] = sin f p:f32[] = mul o g i:f32[] = cos f q:f32[] = mul p h j:f32[] = sin h in (q,) } k:f32[] = cos h l:f32[] = sin j m:f32[] = cos j n:f32[] = sin l o:f32[] = cos l p:f32[] = sin n q:f32[] = cos n in (p, q, o, m, k, i, g, e, c) }
f = recursive_checkpoint([jnp.sin] * 8) print_fwd_bwd(f, 3.)
forward computation: backward computation: { lambda ; a:f32[]. let { lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let b:f32[] = remat2[ e:f32[] = mul d a differentiated=False f:f32[] = mul e b jaxpr={ lambda ; c:f32[]. let d:f32[] = sin c; e:f32[] = sin d in (e,) } g:f32[] = remat2[ policy=None differentiated=True prevent_cse=True jaxpr={ lambda ; h:f32[] i:f32[]. let ] a j:f32[] = sin h f:f32[] = sin b k:f32[] = cos h g:f32[] = sin f l:f32[] = cos j h:f32[] = sin g m:f32[] = mul i l i:f32[] = sin h n:f32[] = mul m k j:f32[] = sin i in (n,) } k:f32[] = cos i policy=None l:f32[] = sin j prevent_cse=True m:f32[] = cos j ] c f in (l, m, k, g, a) } o:f32[] = remat2[ differentiated=True jaxpr={ lambda ; p:f32[] q:f32[]. let r:f32[] = sin p s:f32[] = sin r t:f32[] = sin s u:f32[] = cos s v:f32[] = cos t w:f32[] = mul q v x:f32[] = mul w u y:f32[] = remat2[ differentiated=True jaxpr={ lambda ; z:f32[] ba:f32[]. let bb:f32[] = sin z bc:f32[] = cos z bd:f32[] = cos bb be:f32[] = mul ba bd bf:f32[] = mul be bc in (bf,) } policy=None prevent_cse=True ] p x in (y,) } policy=None prevent_cse=True ] 3.0 g in (o,) }
实际注意事项
当不同函数被分阶段送到 XLA 进行编译时,例如将 jax.jit
应用于包含 jax.grad
调用的函数时,XLA 将自动优化计算,包括决定何时计算或重新生成值。因此,在 jax.jit
下,通常不需要使用 jax.checkpoint
对不同函数进行检查点处理。XLA 将为您优化这些内容。
一个例外是在使用分阶段控制流(例如 jax.lax.scan
)时。跨多个控制流原语的自动编译器优化,例如在正向传播 scan
和相应的反向传播 scan
之间,通常不够彻底。因此,经常建议在传递给 jax.lax.scan
的主体函数上使用 jax.checkpoint
。
例如,在大型Transformer 模型中的一个常见模式是将架构表达为通过层的jax.lax.scan
,以减少编译时间。也就是说,类比于一个简单的全连接网络,我们不是写像这样的代码:
LayerParam = tuple[jnp.ndarray, jnp.ndarray] # weights, bias pair for a layer ParamsList = list[LayerParam] def net(params: ParamsList, x: jnp.ndarray): for W, b in params: x = jnp.maximum(jnp.dot(x, W) + b, 0.) return x
我们可以使用jax.lax.scan
来迭代层应用:
StackedWeights = jnp.ndarray # all weight matrices stacked together StackedBiases = jnp.ndarray # all bias vectors stacked together all_weights = jnp.stack([W for W, _ in params]) all_biases = jnp.stack([b for _, b in params]) def layer(x, W_b_pair): W, b = W_b_pair out = jnp.maximum(jnp.dot(x, W) + b, 0.) return out, None def net(all_weights, all_biases, x): x, _ = jax.lax.scan(layer, x, (all_weights, all_biases)) return x
这种逐层扫描的版本可以减少编译时间,但可能会阻碍一些编译器优化,导致梯度计算效率低下。为了缓解这个问题,我们可以在扫描函数上使用jax.checkpoint
:
from functools import partial @partial(jax.checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable) def layer(x, W_b_pair): W, b = W_b_pair out = jnp.maximum(jnp.dot(x, W) + b, 0.) return out, None
通过这种方式使用jax.checkpoint
,我们手动控制 JAX 自动微分在前向和反向传播之间保存的值,从而不依赖于 XLA 优化来为我们选择。
JAX 中文文档(九)(2)https://developer.aliyun.com/article/1559674