使用整体追踪分析的追踪差异
原文:
pytorch.org/tutorials/beginner/hta_trace_diff_tutorial.html
译者:飞龙
作者: Anupam Bhatnagar
有时,用户需要识别由代码更改导致的 PyTorch 操作符和 CUDA 内核的变化。为了支持这一需求,HTA 提供了一个追踪比较功能。该功能允许用户输入两组追踪文件,第一组可以被视为控制组,第二组可以被视为测试组,类似于 A/B 测试。TraceDiff
类提供了比较追踪之间差异的函数以及可视化这些差异的功能。特别是,用户可以找到每个组中添加和删除的操作符和内核,以及每个操作符/内核的频率和操作符/内核所花费的累积时间。
TraceDiff类具有以下方法:
- compare_traces: 比较两组追踪中 CPU 操作符和 GPU 内核的频率和总持续时间。
- ops_diff: 获取已被以下操作符和内核删除的操作符和内核:
- 添加到测试追踪中并在控制追踪中不存在
- 从测试追踪中删除并存在于控制追踪中
- 在测试追踪中增加并存在于控制追踪中
- 在测试追踪中减少并存在于控制追踪中
- 在两组追踪中未更改
最后两种方法可用于使用compare_traces
方法的输出可视化 CPU 操作符和 GPU 内核的频率和持续时间的各种变化。
例如,可以计算出频率增加最多的前十个操作符如下:
df = compare_traces_output.sort_values(by="diff_counts", ascending=False).head(10) TraceDiff.visualize_counts_diff(df)
类似地,可以计算出持续时间变化最大的前十个操作符如下:
df = compare_traces_output.sort_values(by="diff_duration", ascending=False) # The duration differerence can be overshadowed by the "ProfilerStep", # so we can filter it out to show the trend of other operators. df = df.loc[~df.index.str.startswith("ProfilerStep")].head(10) TraceDiff.visualize_duration_diff(df)
有关此功能的详细示例,请参阅存储库的示例文件夹中的trace_diff_demo notebook。
代码转换与 FX
(beta)在 FX 中构建一个卷积/批量归一化融合器
原文:
pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html
译者:飞龙
注意
点击这里下载完整示例代码
作者:Horace He
在本教程中,我们将使用 FX,一个用于 PyTorch 可组合函数转换的工具包,执行以下操作:
- 在数据依赖关系中查找卷积/批量归一化的模式。
- 对于在 1)中找到的模式,将批量归一化统计数据折叠到卷积权重中。
请注意,此优化仅适用于处于推理模式的模型(即 mode.eval())
我们将构建存在于此处的融合器:github.com/pytorch/pytorch/blob/orig/release/1.8/torch/fx/experimental/fuser.py
首先,让我们导入一些模块(我们稍后将在代码中使用所有这些)。
from typing import Type, Dict, Any, Tuple, Iterable import copy import torch.fx as fx import torch import torch.nn as nn
对于本教程,我们将创建一个由卷积和批量归一化组成的模型。请注意,这个模型有一些棘手的组件 - 一些卷积/批量归一化模式隐藏在 Sequential 中,一个BatchNorms
被包装在另一个模块中。
class WrappedBatchNorm(nn.Module): def __init__(self): super().__init__() self.mod = nn.BatchNorm2d(1) def forward(self, x): return self.mod(x) class M(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1, 1, 1) self.bn1 = nn.BatchNorm2d(1) self.conv2 = nn.Conv2d(1, 1, 1) self.nested = nn.Sequential( nn.BatchNorm2d(1), nn.Conv2d(1, 1, 1), ) self.wrapped = WrappedBatchNorm() def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.conv2(x) x = self.nested(x) x = self.wrapped(x) return x model = M() model.eval()
融合卷积与批量归一化
尝试在 PyTorch 中自动融合卷积和批量归一化的主要挑战之一是 PyTorch 没有提供一种轻松访问计算图的方法。FX 通过符号跟踪实际调用的操作来解决这个问题,这样我们就可以通过前向调用、嵌套在 Sequential 模块中或包装在用户定义模块中来跟踪计算。
traced_model = torch.fx.symbolic_trace(model) print(traced_model.graph)
这给我们提供了模型的图形表示。请注意,顺序内部的模块以及包装的模块都已内联到图中。这是默认的抽象级别,但可以由通道编写者配置。更多信息请参阅 FX 概述pytorch.org/docs/master/fx.html#module-torch.fx
融合卷积与批量归一化
与其他一些融合不同,卷积与批量归一化的融合不需要任何新的运算符。相反,在推理期间,批量归一化由逐点加法和乘法组成,这些操作可以“烘烤”到前面卷积的权重中。这使我们能够完全从我们的模型中删除批量归一化!阅读nenadmarkus.com/p/fusing-batchnorm-and-conv/
获取更多详细信息。这里的代码是从github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py
复制的,以便更清晰。
def fuse_conv_bn_eval(conv, bn): """ Given a conv Module `A` and an batch_norm module `B`, returns a conv module `C` such that C(x) == B(A(x)) in inference mode. """ assert(not (conv.training or bn.training)), "Fusion only for eval!" fused_conv = copy.deepcopy(conv) fused_conv.weight, fused_conv.bias = \ fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias) return fused_conv def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b): if conv_b is None: conv_b = torch.zeros_like(bn_rm) if bn_w is None: bn_w = torch.ones_like(bn_rm) if bn_b is None: bn_b = torch.zeros_like(bn_rm) bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1)) conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)
FX Fusion Pass
现在我们有了我们的计算图以及融合卷积和批量归一化的方法,剩下的就是迭代 FX 图并应用所需的融合。
def _parent_name(target : str) -> Tuple[str, str]: """ Splits a ``qualname`` into parent path and last atom. For example, `foo.bar.baz` -> (`foo.bar`, `baz`) """ *parent, name = target.rsplit('.', 1) return parent[0] if parent else '', name def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module): assert(isinstance(node.target, str)) parent_name, name = _parent_name(node.target) setattr(modules[parent_name], name, new_module) def fuse(model: torch.nn.Module) -> torch.nn.Module: model = copy.deepcopy(model) # The first step of most FX passes is to symbolically trace our model to # obtain a `GraphModule`. This is a representation of our original model # that is functionally identical to our original model, except that we now # also have a graph representation of our forward pass. fx_model: fx.GraphModule = fx.symbolic_trace(model) modules = dict(fx_model.named_modules()) # The primary representation for working with FX are the `Graph` and the # `Node`. Each `GraphModule` has a `Graph` associated with it - this # `Graph` is also what generates `GraphModule.code`. # The `Graph` itself is represented as a list of `Node` objects. Thus, to # iterate through all of the operations in our graph, we iterate over each # `Node` in our `Graph`. for node in fx_model.graph.nodes: # The FX IR contains several types of nodes, which generally represent # call sites to modules, functions, or methods. The type of node is # determined by `Node.op`. if node.op != 'call_module': # If our current node isn't calling a Module then we can ignore it. continue # For call sites, `Node.target` represents the module/function/method # that's being called. Here, we check `Node.target` to see if it's a # batch norm module, and then check `Node.args[0].target` to see if the # input `Node` is a convolution. if type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d: if len(node.args[0].users) > 1: # Output of conv is used by other nodes continue conv = modules[node.args[0].target] bn = modules[node.target] fused_conv = fuse_conv_bn_eval(conv, bn) replace_node_module(node.args[0], modules, fused_conv) # As we've folded the batch nor into the conv, we need to replace all uses # of the batch norm with the conv. node.replace_all_uses_with(node.args[0]) # Now that all uses of the batch norm have been replaced, we can # safely remove the batch norm. fx_model.graph.erase_node(node) fx_model.graph.lint() # After we've modified our graph, we need to recompile our graph in order # to keep the generated code in sync. fx_model.recompile() return fx_model
注意
为了演示目的,我们在这里进行了一些简化,比如只匹配 2D 卷积。查看github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py
以获取更可用的通道。
测试我们的融合通道
现在我们可以在初始的玩具模型上运行这个融合通道,并验证我们的结果是相同的。此外,我们可以打印出我们融合模型的代码,并验证是否还有批量归一化。
fused_model = fuse(model) print(fused_model.code) inp = torch.randn(5, 1, 1, 1) torch.testing.assert_allclose(fused_model(inp), model(inp))
在 ResNet18 上对我们的融合进行基准测试
我们可以在像 ResNet18 这样的较大模型上测试我们的融合通道,看看这个通道如何提高推理性能。
import torchvision.models as models import time rn18 = models.resnet18() rn18.eval() inp = torch.randn(10, 3, 224, 224) output = rn18(inp) def benchmark(model, iters=20): for _ in range(10): model(inp) begin = time.time() for _ in range(iters): model(inp) return str(time.time()-begin) fused_rn18 = fuse(rn18) print("Unfused time: ", benchmark(rn18)) print("Fused time: ", benchmark(fused_rn18))
正如我们之前看到的,我们的 FX 转换的输出是(“torchscriptable”)PyTorch 代码,我们可以轻松地jit.script
输出,尝试进一步提高性能。通过这种方式,我们的 FX 模型转换与 TorchScript 组合在一起,没有任何问题。
jit_rn18 = torch.jit.script(fused_rn18) print("jit time: ", benchmark(jit_rn18)) ############ # Conclusion # ---------- # As we can see, using FX we can easily write static graph transformations on # PyTorch code. # # Since FX is still in beta, we would be happy to hear any # feedback you have about using it. Please feel free to use the # PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker # (https://github.com/pytorch/pytorch/issues) to provide any feedback # you might have.
脚本的总运行时间:(0 分钟 0.000 秒)
下载 Python 源代码:fx_conv_bn_fuser.py
下载 Jupyter 笔记本:fx_conv_bn_fuser.ipynb
(beta)使用 FX 构建一个简单的 CPU 性能分析器
原文:
pytorch.org/tutorials/intermediate/fx_profiling_tutorial.html
译者:飞龙
注意
点击这里下载完整的示例代码
作者:James Reed
在本教程中,我们将使用 FX 来执行以下操作:
- 以一种我们可以检查和收集关于代码结构和执行的统计信息的方式捕获 PyTorch Python 代码
- 构建一个小类,作为一个简单的性能“分析器”,收集关于模型各部分的运行时统计信息。
在本教程中,我们将使用 torchvision ResNet18 模型进行演示。
import torch import torch.fx import torchvision.models as models rn18 = models.resnet18() rn18.eval()
ResNet( (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (layer1): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer2): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer3): Sequential( (0): BasicBlock( (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer4): Sequential( (0): BasicBlock( (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (avgpool): AdaptiveAvgPool2d(output_size=(1, 1)) (fc): Linear(in_features=512, out_features=1000, bias=True) )
现在我们有了我们的模型,我们想要更深入地检查其性能。也就是说,在以下调用中,模型的哪些部分花费时间最长?
input = torch.randn(5, 3, 224, 224) output = rn18(input)
回答这个问题的常见方法是浏览程序源代码,在程序的各个点添加收集时间戳的代码,并比较这些时间戳之间的差异,以查看这些时间戳之间的区域需要多长时间。
这种技术当然适用于 PyTorch 代码,但如果我们不必复制模型代码并进行编辑,尤其是我们没有编写的代码(比如这个 torchvision 模型),那将更好。相反,我们将使用 FX 自动化这个“仪器化”过程,而无需修改任何源代码。
首先,让我们解决一些导入问题(我们稍后将在代码中使用所有这些)。
import statistics, tabulate, time from typing import Any, Dict, List from torch.fx import Interpreter
注意
tabulate
是一个外部库,不是 PyTorch 的依赖项。我们将使用它来更轻松地可视化性能数据。请确保您已从您喜欢的 Python 软件包源安装了它。
使用符号跟踪捕获模型
接下来,我们将使用 FX 的符号跟踪机制来捕获我们模型的定义,以便我们可以操作和检查它。
traced_rn18 = torch.fx.symbolic_trace(rn18) print(traced_rn18.graph)
graph(): %x : torch.Tensor [num_users=1] = placeholder[target=x] %conv1 : [num_users=1] = call_moduletarget=conv1, kwargs = {}) %bn1 : [num_users=1] = call_moduletarget=bn1, kwargs = {}) %relu : [num_users=1] = call_moduletarget=relu, kwargs = {}) %maxpool : [num_users=2] = call_moduletarget=maxpool, kwargs = {}) %layer1_0_conv1 : [num_users=1] = call_moduletarget=layer1.0.conv1, kwargs = {}) %layer1_0_bn1 : [num_users=1] = call_moduletarget=layer1.0.bn1, kwargs = {}) %layer1_0_relu : [num_users=1] = call_moduletarget=layer1.0.relu, kwargs = {}) %layer1_0_conv2 : [num_users=1] = call_moduletarget=layer1.0.conv2, kwargs = {}) %layer1_0_bn2 : [num_users=1] = call_moduletarget=layer1.0.bn2, kwargs = {}) %add : [num_users=1] = call_functiontarget=operator.add, kwargs = {}) %layer1_0_relu_1 : [num_users=2] = call_moduletarget=layer1.0.relu, kwargs = {}) %layer1_1_conv1 : [num_users=1] = call_moduletarget=layer1.1.conv1, kwargs = {}) %layer1_1_bn1 : [num_users=1] = call_moduletarget=layer1.1.bn1, kwargs = {}) %layer1_1_relu : [num_users=1] = call_moduletarget=layer1.1.relu, kwargs = {}) %layer1_1_conv2 : [num_users=1] = call_moduletarget=layer1.1.conv2, kwargs = {}) %layer1_1_bn2 : [num_users=1] = call_moduletarget=layer1.1.bn2, kwargs = {}) %add_1 : [num_users=1] = call_functiontarget=operator.add, kwargs = {}) %layer1_1_relu_1 : [num_users=2] = call_moduletarget=layer1.1.relu, kwargs = {}) %layer2_0_conv1 : [num_users=1] = call_moduletarget=layer2.0.conv1, kwargs = {}) %layer2_0_bn1 : [num_users=1] = call_moduletarget=layer2.0.bn1, kwargs = {}) %layer2_0_relu : [num_users=1] = call_moduletarget=layer2.0.relu, kwargs = {}) %layer2_0_conv2 : [num_users=1] = call_moduletarget=layer2.0.conv2, kwargs = {}) %layer2_0_bn2 : [num_users=1] = call_moduletarget=layer2.0.bn2, kwargs = {}) %layer2_0_downsample_0 : [num_users=1] = call_moduletarget=layer2.0.downsample.0, kwargs = {}) %layer2_0_downsample_1 : [num_users=1] = call_moduletarget=layer2.0.downsample.1, kwargs = {}) %add_2 : [num_users=1] = call_functiontarget=operator.add, kwargs = {}) %layer2_0_relu_1 : [num_users=2] = call_moduletarget=layer2.0.relu, kwargs = {}) %layer2_1_conv1 : [num_users=1] = call_moduletarget=layer2.1.conv1, kwargs = {}) %layer2_1_bn1 : [num_users=1] = call_moduletarget=layer2.1.bn1, kwargs = {}) %layer2_1_relu : [num_users=1] = call_moduletarget=layer2.1.relu, kwargs = {}) %layer2_1_conv2 : [num_users=1] = call_moduletarget=layer2.1.conv2, kwargs = {}) %layer2_1_bn2 : [num_users=1] = call_moduletarget=layer2.1.bn2, kwargs = {}) %add_3 : [num_users=1] = call_functiontarget=operator.add, kwargs = {}) %layer2_1_relu_1 : [num_users=2] = call_moduletarget=layer2.1.relu, kwargs = {}) %layer3_0_conv1 : [num_users=1] = call_moduletarget=layer3.0.conv1, kwargs = {}) %layer3_0_bn1 : [num_users=1] = call_moduletarget=layer3.0.bn1, kwargs = {}) %layer3_0_relu : [num_users=1] = call_moduletarget=layer3.0.relu, kwargs = {}) %layer3_0_conv2 : [num_users=1] = call_moduletarget=layer3.0.conv2, kwargs = {}) %layer3_0_bn2 : [num_users=1] = call_moduletarget=layer3.0.bn2, kwargs = {}) %layer3_0_downsample_0 : [num_users=1] = call_moduletarget=layer3.0.downsample.0, kwargs = {}) %layer3_0_downsample_1 : [num_users=1] = call_moduletarget=layer3.0.downsample.1, kwargs = {}) %add_4 : [num_users=1] = call_functiontarget=operator.add, kwargs = {}) %layer3_0_relu_1 : [num_users=2] = call_moduletarget=layer3.0.relu, kwargs = {}) %layer3_1_conv1 : [num_users=1] = call_moduletarget=layer3.1.conv1, kwargs = {}) %layer3_1_bn1 : [num_users=1] = call_moduletarget=layer3.1.bn1, kwargs = {}) %layer3_1_relu : [num_users=1] = call_moduletarget=layer3.1.relu, kwargs = {}) %layer3_1_conv2 : [num_users=1] = call_moduletarget=layer3.1.conv2, kwargs = {}) %layer3_1_bn2 : [num_users=1] = call_moduletarget=layer3.1.bn2, kwargs = {}) %add_5 : [num_users=1] = call_functiontarget=operator.add, kwargs = {}) %layer3_1_relu_1 : [num_users=2] = call_moduletarget=layer3.1.relu, kwargs = {}) %layer4_0_conv1 : [num_users=1] = call_moduletarget=layer4.0.conv1, kwargs = {}) %layer4_0_bn1 : [num_users=1] = call_moduletarget=layer4.0.bn1, kwargs = {}) %layer4_0_relu : [num_users=1] = call_moduletarget=layer4.0.relu, kwargs = {}) %layer4_0_conv2 : [num_users=1] = call_moduletarget=layer4.0.conv2, kwargs = {}) %layer4_0_bn2 : [num_users=1] = call_moduletarget=layer4.0.bn2, kwargs = {}) %layer4_0_downsample_0 : [num_users=1] = call_moduletarget=layer4.0.downsample.0, kwargs = {}) %layer4_0_downsample_1 : [num_users=1] = call_moduletarget=layer4.0.downsample.1, kwargs = {}) %add_6 : [num_users=1] = call_functiontarget=operator.add, kwargs = {}) %layer4_0_relu_1 : [num_users=2] = call_moduletarget=layer4.0.relu, kwargs = {}) %layer4_1_conv1 : [num_users=1] = call_moduletarget=layer4.1.conv1, kwargs = {}) %layer4_1_bn1 : [num_users=1] = call_moduletarget=layer4.1.bn1, kwargs = {}) %layer4_1_relu : [num_users=1] = call_moduletarget=layer4.1.relu, kwargs = {}) %layer4_1_conv2 : [num_users=1] = call_moduletarget=layer4.1.conv2, kwargs = {}) %layer4_1_bn2 : [num_users=1] = call_moduletarget=layer4.1.bn2, kwargs = {}) %add_7 : [num_users=1] = call_functiontarget=operator.add, kwargs = {}) %layer4_1_relu_1 : [num_users=1] = call_moduletarget=layer4.1.relu, kwargs = {}) %avgpool : [num_users=1] = call_moduletarget=avgpool, kwargs = {}) %flatten : [num_users=1] = call_functiontarget=torch.flatten, kwargs = {}) %fc : [num_users=1] = call_moduletarget=fc, kwargs = {}) return fc
这为我们提供了 ResNet18 模型的图形表示。图形由一系列相互连接的节点组成。每个节点代表 Python 代码中的调用点(无论是函数、模块还是方法),边缘(在每个节点上表示为args
和kwargs
)代表这些调用点之间传递的值。有关图形表示和 FX 的其余 API 的更多信息,请参阅 FX 文档pytorch.org/docs/master/fx.html
。
PyTorch 2.2 中文官方教程(十)(2)https://developer.aliyun.com/article/1482539