PyTorch 2.2 中文官方教程(十)(2)https://developer.aliyun.com/article/1482539
使用模块
要使用前向自动微分与 nn.Module
,在执行前向传递之前,将模型的参数替换为双张量。在撰写本文时,不可能创建双张量 nn.Parameter
。作为解决方法,必须将双张量注册为模块的非参数属性。
import torch.nn as nn model = nn.Linear(5, 5) input = torch.randn(16, 5) params = {name: p for name, p in model.named_parameters()} tangents = {name: torch.rand_like(p) for name, p in params.items()} with fwAD.dual_level(): for name, p in params.items(): delattr(model, name) setattr(model, name, fwAD.make_dual(p, tangents[name])) out = model(input) jvp = fwAD.unpack_dual(out).tangent
使用功能模块 API(beta)
使用前向自动微分的另一种方法是利用功能模块 API(也称为无状态模块 API)。
from torch.func import functional_call # We need a fresh module because the functional call requires the # the model to have parameters registered. model = nn.Linear(5, 5) dual_params = {} with fwAD.dual_level(): for name, p in params.items(): # Using the same ``tangents`` from the above section dual_params[name] = fwAD.make_dual(p, tangents[name]) out = functional_call(model, dual_params, input) jvp2 = fwAD.unpack_dual(out).tangent # Check our results assert torch.allclose(jvp, jvp2)
自定义 autograd 函数
自定义函数还支持前向模式自动微分。要创建支持前向模式自动微分的自定义函数,请注册 jvp()
静态方法。自定义函数可以支持前向和反向自动微分,但这不是强制的。有关更多信息,请参阅文档。
class Fn(torch.autograd.Function): @staticmethod def forward(ctx, foo): result = torch.exp(foo) # Tensors stored in ``ctx`` can be used in the subsequent forward grad # computation. ctx.result = result return result @staticmethod def jvp(ctx, gI): gO = gI * ctx.result # If the tensor stored in`` ctx`` will not also be used in the backward pass, # one can manually free it using ``del`` del ctx.result return gO fn = Fn.apply primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True) tangent = torch.randn(10, 10) with fwAD.dual_level(): dual_input = fwAD.make_dual(primal, tangent) dual_output = fn(dual_input) jvp = fwAD.unpack_dual(dual_output).tangent # It is important to use ``autograd.gradcheck`` to verify that your # custom autograd Function computes the gradients correctly. By default, # ``gradcheck`` only checks the backward-mode (reverse-mode) AD gradients. Specify # ``check_forward_ad=True`` to also check forward grads. If you did not # implement the backward formula for your function, you can also tell ``gradcheck`` # to skip the tests that require backward-mode AD by specifying # ``check_backward_ad=False``, ``check_undefined_grad=False``, and # ``check_batched_grad=False``. torch.autograd.gradcheck(Fn.apply, (primal,), check_forward_ad=True, check_backward_ad=False, check_undefined_grad=False, check_batched_grad=False)
True
功能 API(beta)
我们还提供了 functorch 中用于计算雅可比向量积的更高级功能 API,根据您的用例,您可能会发现更简单使用。
功能 API 的好处是不需要理解或使用较低级别的双张量 API,并且可以将其与其他 functorch 转换(如 vmap)组合;缺点是它提供的控制较少。
请注意,本教程的其余部分将需要 functorch (github.com/pytorch/functorch
) 来运行。请在指定的链接找到安装说明。
import functorch as ft primal0 = torch.randn(10, 10) tangent0 = torch.randn(10, 10) primal1 = torch.randn(10, 10) tangent1 = torch.randn(10, 10) def fn(x, y): return x ** 2 + y ** 2 # Here is a basic example to compute the JVP of the above function. # The ``jvp(func, primals, tangents)`` returns ``func(*primals)`` as well as the # computed Jacobian-vector product (JVP). Each primal must be associated with a tangent of the same shape. primal_out, tangent_out = ft.jvp(fn, (primal0, primal1), (tangent0, tangent1)) # ``functorch.jvp`` requires every primal to be associated with a tangent. # If we only want to associate certain inputs to `fn` with tangents, # then we'll need to create a new function that captures inputs without tangents: primal = torch.randn(10, 10) tangent = torch.randn(10, 10) y = torch.randn(10, 10) import functools new_fn = functools.partial(fn, y=y) primal_out, tangent_out = ft.jvp(new_fn, (primal,), (tangent,))
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/deprecated.py:77: UserWarning: We've integrated functorch into PyTorch. As the final step of the integration, functorch.jvp is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3\. Please use torch.func.jvp instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details https://pytorch.org/docs/master/func.migrating.html
使用功能 API 与模块
要使用 functorch.jvp
与 nn.Module
一起计算相对于模型参数的雅可比向量积,我们需要将 nn.Module
重新构建为一个接受模型参数和模块输入的函数。
model = nn.Linear(5, 5) input = torch.randn(16, 5) tangents = tuple([torch.rand_like(p) for p in model.parameters()]) # Given a ``torch.nn.Module``, ``ft.make_functional_with_buffers`` extracts the state # (``params`` and buffers) and returns a functional version of the model that # can be invoked like a function. # That is, the returned ``func`` can be invoked like # ``func(params, buffers, input)``. # ``ft.make_functional_with_buffers`` is analogous to the ``nn.Modules`` stateless API # that you saw previously and we're working on consolidating the two. func, params, buffers = ft.make_functional_with_buffers(model) # Because ``jvp`` requires every input to be associated with a tangent, we need to # create a new function that, when given the parameters, produces the output def func_params_only(params): return func(params, buffers, input) model_output, jvp_out = ft.jvp(func_params_only, (params,), (tangents,))
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/deprecated.py:104: UserWarning: We've integrated functorch into PyTorch. As the final step of the integration, functorch.make_functional_with_buffers is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3\. Please use torch.func.functional_call instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details https://pytorch.org/docs/master/func.migrating.html /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_functorch/deprecated.py:77: UserWarning: We've integrated functorch into PyTorch. As the final step of the integration, functorch.jvp is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3\. Please use torch.func.jvp instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details https://pytorch.org/docs/master/func.migrating.html
[0] en.wikipedia.org/wiki/Dual_number
脚本的总运行时间:(0 分钟 0.149 秒)
下载 Python 源代码:forward_ad_usage.py
下载 Jupyter 笔记本:forward_ad_usage.ipynb
雅可比矩阵、海森矩阵、hvp、vhp 等:组合函数转换
原文:
pytorch.org/tutorials/intermediate/jacobians_hessians.html
译者:飞龙
注意
点击这里下载完整的示例代码
计算雅可比矩阵或海森矩阵在许多非传统的深度学习模型中是有用的。使用 PyTorch 的常规自动微分 API(Tensor.backward()
,torch.autograd.grad
)高效地计算这些量是困难的(或者烦人的)。PyTorch 的 受 JAX 启发的 函数转换 API 提供了高效计算各种高阶自动微分量的方法。
注意
本教程需要 PyTorch 2.0.0 或更高版本。
计算雅可比矩阵
import torch import torch.nn.functional as F from functools import partial _ = torch.manual_seed(0)
让我们从一个我们想要计算雅可比矩阵的函数开始。这是一个带有非线性激活的简单线性函数。
def predict(weight, bias, x): return F.linear(x, weight, bias).tanh()
让我们添加一些虚拟数据:一个权重、一个偏置和一个特征向量 x。
D = 16 weight = torch.randn(D, D) bias = torch.randn(D) x = torch.randn(D) # feature vector
让我们将 predict
视为一个将输入 x
从 R D → R D R^D \to R^DRD→RD 的函数。PyTorch Autograd 计算向量-雅可比乘积。为了计算这个 R D → R D R^D \to R^DRD→RD 函数的完整雅可比矩阵,我们将不得不逐行计算,每次使用一个不同的单位向量。
def compute_jac(xp): jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0] for vec in unit_vectors] return torch.stack(jacobian_rows) xp = x.clone().requires_grad_() unit_vectors = torch.eye(D) jacobian = compute_jac(xp) print(jacobian.shape) print(jacobian[0]) # show first row
torch.Size([16, 16]) tensor([-0.5956, -0.6096, -0.1326, -0.2295, 0.4490, 0.3661, -0.1672, -1.1190, 0.1705, -0.6683, 0.1851, 0.1630, 0.0634, 0.6547, 0.5908, -0.1308])
我们可以使用 PyTorch 的 torch.vmap
函数转换来消除循环并向量化计算,而不是逐行计算雅可比矩阵。我们不能直接将 vmap
应用于 torch.autograd.grad
;相反,PyTorch 提供了一个 torch.func.vjp
转换,与 torch.vmap
组合使用:
from torch.func import vmap, vjp _, vjp_fn = vjp(partial(predict, weight, bias), x) ft_jacobian, = vmap(vjp_fn)(unit_vectors) # let's confirm both methods compute the same result assert torch.allclose(ft_jacobian, jacobian)
在后续教程中,反向模式自动微分和 vmap
的组合将给我们提供每个样本的梯度。在本教程中,组合反向模式自动微分和 vmap
将给我们提供雅可比矩阵的计算!vmap
和自动微分转换的各种组合可以给我们提供不同的有趣量。
PyTorch 提供了 torch.func.jacrev
作为一个方便的函数,执行 vmap-vjp
组合来计算雅可比矩阵。jacrev
接受一个 argnums
参数,指定我们想要相对于哪个参数计算雅可比矩阵。
from torch.func import jacrev ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x) # Confirm by running the following: assert torch.allclose(ft_jacobian, jacobian)
让我们比较两种计算雅可比矩阵的方式的性能。函数转换版本要快得多(并且随着输出数量的增加而变得更快)。
一般来说,我们期望通过 vmap
的向量化可以帮助消除开销,并更好地利用硬件。
vmap
通过将外部循环下推到函数的原始操作中,以获得更好的性能。
让我们快速创建一个函数来评估性能,并处理微秒和毫秒的测量:
def get_perf(first, first_descriptor, second, second_descriptor): """takes torch.benchmark objects and compares delta of second vs first.""" faster = second.times[0] slower = first.times[0] gain = (slower-faster)/slower if gain < 0: gain *=-1 final_gain = gain*100 print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ")
然后进行性能比较:
from torch.utils.benchmark import Timer without_vmap = Timer(stmt="compute_jac(xp)", globals=globals()) with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals()) no_vmap_timer = without_vmap.timeit(500) with_vmap_timer = with_vmap.timeit(500) print(no_vmap_timer) print(with_vmap_timer)
<torch.utils.benchmark.utils.common.Measurement object at 0x7fc093552980> compute_jac(xp) 1.43 ms 1 measurement, 500 runs , 1 thread <torch.utils.benchmark.utils.common.Measurement object at 0x7fc0914a7790> jacrev(predict, argnums=2)(weight, bias, x) 435.16 us 1 measurement, 500 runs , 1 thread
让我们通过我们的 get_perf
函数进行上述的相对性能比较:
get_perf(no_vmap_timer, "without vmap", with_vmap_timer, "vmap")
Performance delta: 69.4681 percent improvement with vmap
此外,很容易将问题转换过来,说我们想要计算模型参数(权重、偏置)的雅可比矩阵,而不是输入的雅可比矩阵
# note the change in input via ``argnums`` parameters of 0,1 to map to weight and bias ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
反向模式雅可比矩阵(jacrev
) vs 正向模式雅可比矩阵(jacfwd
)
我们提供了两个 API 来计算雅可比矩阵:jacrev
和 jacfwd
:
jacrev
使用反向模式自动微分。正如你在上面看到的,它是我们vjp
和vmap
转换的组合。jacfwd
使用正向模式自动微分。它是我们jvp
和vmap
转换的组合实现。
jacfwd
和 jacrev
可以互相替代,但它们具有不同的性能特征。
作为一个经验法则,如果你正在计算一个 R N → R M R^N \to R^MRN→RM 函数的雅可比矩阵,并且输出比输入要多得多(例如,M > N M > NM>N),那么首选 jacfwd
,否则使用 jacrev
。当然,这个规则也有例外,但以下是一个非严格的论证:
在反向模式 AD 中,我们逐行计算雅可比矩阵,而在正向模式 AD(计算雅可比向量积)中,我们逐列计算。雅可比矩阵有 M 行和 N 列,因此如果它在某个方向上更高或更宽,我们可能更喜欢处理较少行或列的方法。
from torch.func import jacrev, jacfwd
首先,让我们使用更多的输入进行基准测试:
Din = 32 Dout = 2048 weight = torch.randn(Dout, Din) bias = torch.randn(Dout) x = torch.randn(Din) # remember the general rule about taller vs wider... here we have a taller matrix: print(weight.shape) using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals()) using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals()) jacfwd_timing = using_fwd.timeit(500) jacrev_timing = using_bwd.timeit(500) print(f'jacfwd time: {jacfwd_timing}') print(f'jacrev time: {jacrev_timing}')
torch.Size([2048, 32]) jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fc091472d10> jacfwd(predict, argnums=2)(weight, bias, x) 773.29 us 1 measurement, 500 runs , 1 thread jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fc0936e6b00> jacrev(predict, argnums=2)(weight, bias, x) 8.54 ms 1 measurement, 500 runs , 1 thread
然后进行相对基准测试:
get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev", );
Performance delta: 1004.5112 percent improvement with jacrev
现在反过来 - 输出(M)比输入(N)更多:
Din = 2048 Dout = 32 weight = torch.randn(Dout, Din) bias = torch.randn(Dout) x = torch.randn(Din) using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals()) using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals()) jacfwd_timing = using_fwd.timeit(500) jacrev_timing = using_bwd.timeit(500) print(f'jacfwd time: {jacfwd_timing}') print(f'jacrev time: {jacrev_timing}')
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fc0915995a0> jacfwd(predict, argnums=2)(weight, bias, x) 7.15 ms 1 measurement, 500 runs , 1 thread jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fc091473d60> jacrev(predict, argnums=2)(weight, bias, x) 533.13 us 1 measurement, 500 runs , 1 thread
以及相对性能比较:
get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd")
Performance delta: 1241.8207 percent improvement with jacfwd
使用 functorch.hessian 进行 Hessian 计算
我们提供了一个方便的 API 来计算 Hessian:torch.func.hessiani
。Hessians 是雅可比矩阵的雅可比矩阵(或偏导数的偏导数,也称为二阶导数)。
这表明可以简单地组合 functorch 雅可比变换来计算 Hessian。实际上,在内部,hessian(f)
就是jacfwd(jacrev(f))
。
注意:为了提高性能:根据您的模型,您可能还希望使用jacfwd(jacfwd(f))
或jacrev(jacrev(f))
来计算 Hessian,利用上述关于更宽还是更高矩阵的经验法则。
from torch.func import hessian # lets reduce the size in order not to overwhelm Colab. Hessians require # significant memory: Din = 512 Dout = 32 weight = torch.randn(Dout, Din) bias = torch.randn(Dout) x = torch.randn(Din) hess_api = hessian(predict, argnums=2)(weight, bias, x) hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x) hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
让我们验证无论是使用 Hessian API 还是使用jacfwd(jacfwd())
,我们都会得到相同的结果。
torch.allclose(hess_api, hess_fwdfwd)
True
批处理雅可比矩阵和批处理 Hessian
在上面的例子中,我们一直在操作单个特征向量。在某些情况下,您可能希望对一批输出相对于一批输入进行雅可比矩阵的计算。也就是说,给定形状为(B, N)
的输入批次和一个从R N → R M R^N \to R^MRN→RM的函数,我们希望得到形状为(B, M, N)
的雅可比矩阵。
使用vmap
是最简单的方法:
batch_size = 64 Din = 31 Dout = 33 weight = torch.randn(Dout, Din) print(f"weight shape = {weight.shape}") bias = torch.randn(Dout) x = torch.randn(batch_size, Din) compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0)) batch_jacobian0 = compute_batch_jacobian(weight, bias, x)
weight shape = torch.Size([33, 31])
如果您有一个从(B, N) -> (B, M)的函数,而且确定每个输入产生独立的输出,那么有时也可以通过对输出求和,然后计算该函数的雅可比矩阵来实现,而无需使用vmap
:
def predict_with_output_summed(weight, bias, x): return predict(weight, bias, x).sum(0) batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0) assert torch.allclose(batch_jacobian0, batch_jacobian1)
如果您的函数是从R N → R M R^N \to R^MRN→RM,但输入是批处理的,您可以组合vmap
和jacrev
来计算批处理雅可比矩阵:
最后,批次 Hessian 矩阵的计算方式类似。最容易的方法是使用vmap
批处理 Hessian 计算,但在某些情况下,求和技巧也适用。
compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0)) batch_hess = compute_batch_hessian(weight, bias, x) batch_hess.shape
torch.Size([64, 33, 31, 31])
PyTorch 2.2 中文官方教程(十)(4)https://developer.aliyun.com/article/1482542