PyTorch 2.2 Official Chinese Tutorials (10) (1): https://developer.aliyun.com/article/1482537
Creating a Profiling Interpreter
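Next, we will create a class that inherits from `torch.fx.Interpreter`. The `GraphModule` produced by `symbolic_trace` compiles Python code that runs when you call the `GraphModule`. Another way to run a `GraphModule` is to execute each `Node` in the `Graph` one at a time. That is what `Interpreter` provides: it interprets the graph node by node.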
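By inheriting from `Interpreter`, we can override various pieces of functionality and install the profiling behavior we want. The goal is an object that we can pass a model to, call the model one or more times, and then get statistics on how long the model, and each part of the model, took across those runs.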
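The following minimal pure-Python sketch makes the node-by-node idea concrete without depending on torch.fx internals. `ToyInterpreter`, `ToyProfilingInterpreter` and the tuple-based node format are hypothetical stand-ins invented for illustration. They are not FX APIs. The sketch only shows the override pattern: `run` walks the nodes in order, and a subclass can intercept `run_node` to time each step.

```python
import operator
import time
from typing import Any, Callable, Dict, List, Tuple

# A hypothetical toy "graph": each node is (name, function, names of its inputs).
ToyNode = Tuple[str, Callable[..., Any], Tuple[str, ...]]


class ToyInterpreter:
    """Runs a list of nodes one by one, loosely mimicking torch.fx.Interpreter."""

    def __init__(self, nodes: List[ToyNode], output: str):
        self.nodes = nodes
        self.output = output

    def run_node(self, node: ToyNode, env: Dict[str, Any]) -> Any:
        # Look up the node's inputs in the environment and call its function.
        _, fn, arg_names = node
        return fn(*(env[a] for a in arg_names))

    def run(self, **inputs: Any) -> Any:
        # Top-level entry point: evaluate every node in order, return the output.
        env: Dict[str, Any] = dict(inputs)
        for node in self.nodes:
            env[node[0]] = self.run_node(node, env)
        return env[self.output]


class ToyProfilingInterpreter(ToyInterpreter):
    """Overrides only run_node, so every node is timed without touching run()."""

    def __init__(self, nodes: List[ToyNode], output: str):
        super().__init__(nodes, output)
        self.runtimes_sec: Dict[str, List[float]] = {}

    def run_node(self, node: ToyNode, env: Dict[str, Any]) -> Any:
        t_start = time.perf_counter()
        result = super().run_node(node, env)
        self.runtimes_sec.setdefault(node[0], []).append(time.perf_counter() - t_start)
        return result


# y = (x + one) * x, expressed as two nodes
nodes: List[ToyNode] = [
    ("add", operator.add, ("x", "one")),
    ("mul", operator.mul, ("add", "x")),
]
interp = ToyProfilingInterpreter(nodes, output="mul")
print(interp.run(x=3, one=1))       # prints 12
print(sorted(interp.runtimes_sec))  # prints ['add', 'mul']
```

The real `Interpreter` works on FX `Node` objects and dispatches on their `op` kind. The subclassing pattern is the same one the class below uses.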
Let's define our `ProfilingInterpreter` class:
```python
import statistics
import time
from typing import Any, Dict, List

import tabulate
import torch
import torch.fx
from torch.fx import Interpreter


class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}

    ######################################################################
    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
    # method is the top-level entry point for execution of the model. We will
    # want to intercept this so that we can record the total runtime of the
    # model.

    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ``ProfilingInterpreter``
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
    # time it executes a single node. We will intercept this so that we
    # can measure and record the time taken for each individual call in
    # the model.

    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    ######################################################################
    # Finally, we are going to define a method (one which doesn't override
    # any ``Interpreter`` method) that provides us a nice, organized view of
    # the data we have collected.

    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage of
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime.
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])

        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)
```
Note
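We use Python's `time.time` function to take wall-clock timestamps and compare them. This is not the most accurate way to measure performance. It gives only a first-order approximation, and we use this simple technique purely for demonstration.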
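When more accurate numbers are needed, a common refinement has three parts: use `time.perf_counter` (monotonic, with the highest available resolution), discard a few warm-up calls, and summarize many samples with the median. This tutorial does not do any of that; it is an assumption added here. The helper `time_call` below is hypothetical. CUDA kernels also launch asynchronously, so GPU timings would additionally need `torch.cuda.synchronize()` before each clock read. PyTorch ships `torch.utils.benchmark` for this kind of measurement as well.

```python
import statistics
import time
from typing import Callable, List


def time_call(fn: Callable[[], object], repeats: int = 10, warmup: int = 2) -> List[float]:
    """Return `repeats` wall-clock durations (seconds) of fn(), after `warmup` untimed calls."""
    for _ in range(warmup):
        fn()  # discarded: first calls often pay one-off costs (caches, lazy init)
    samples: List[float] = []
    for _ in range(repeats):
        t_start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        samples.append(time.perf_counter() - t_start)
    return samples


samples = time_call(lambda: sum(range(10_000)), repeats=20)
print(f"median {statistics.median(samples):.6f} s over {len(samples)} runs")
```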
Investigating the Performance of ResNet18
We can now use `ProfilingInterpreter` to inspect the performance characteristics of our ResNet18 model:
```python
interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
```
```
Op type        Op                     Average runtime (s)    Pct total runtime
-------------  ---------------------  ---------------------  -------------------
call_module    maxpool                0.0058043              9.43883
call_module    conv1                  0.00556087             9.04297
call_module    layer4_0_conv2         0.00342155             5.56404
call_module    layer4_1_conv2         0.00325394             5.29148
call_module    layer4_1_conv1         0.00316119             5.14066
call_module    layer1_0_conv2         0.00267935             4.3571
call_module    layer1_1_conv1         0.00267816             4.35516
call_module    layer3_0_conv2         0.00267792             4.35477
call_module    layer3_1_conv1         0.00261283             4.24893
call_module    layer3_1_conv2         0.00259137             4.21403
call_module    layer1_0_conv1         0.00256515             4.17138
call_module    layer2_1_conv1         0.00249219             4.05274
call_module    layer2_1_conv2         0.0024581              3.9973
call_module    layer2_0_conv2         0.00242114             3.93721
call_module    layer1_1_conv2         0.00241613             3.92906
call_module    layer4_0_conv1         0.00203657             3.31183
call_module    layer3_0_conv1         0.00165725             2.69498
call_module    layer2_0_conv1         0.00164604             2.67676
call_module    bn1                    0.00133991             2.17894
call_module    layer2_0_downsample_0  0.000616312            1.00223
call_module    layer3_0_downsample_0  0.000507832            0.825825
call_module    layer4_0_downsample_0  0.000471115            0.766117
call_function  add                    0.00034976             0.568772
call_module    relu                   0.000216722            0.352429
call_function  add_1                  0.000201702            0.328004
call_module    fc                     0.000183105            0.297762
call_module    layer1_0_bn1           0.000178337            0.290008
call_module    layer1_0_bn2           0.000164032            0.266745
call_module    layer1_1_bn1           0.000163794            0.266358
call_module    layer1_1_bn2           0.000160933            0.261705
call_module    avgpool                0.000149012            0.242319
call_module    layer2_1_bn2           0.000141621            0.2303
call_module    layer2_0_downsample_1  0.000141382            0.229913
call_module    layer4_0_bn2           0.000140429            0.228362
call_module    layer2_0_bn1           0.000137806            0.224097
call_module    layer4_1_bn2           0.000136852            0.222546
call_module    layer2_1_bn1           0.000136137            0.221383
call_module    layer2_0_bn2           0.000132799            0.215955
call_module    layer1_1_relu          0.000128984            0.209752
call_function  add_2                  0.000127316            0.207038
call_module    layer3_1_bn1           0.000127316            0.207038
call_module    layer3_0_downsample_1  0.0001266              0.205875
call_module    layer3_0_bn1           0.000126362            0.205487
call_module    layer3_0_bn2           0.000125647            0.204324
call_function  add_3                  0.000124454            0.202385
call_module    layer3_1_bn2           0.000123978            0.20161
call_module    layer4_1_bn1           0.000119686            0.194631
call_module    layer4_0_downsample_1  0.000118017            0.191917
call_module    layer4_0_bn1           0.000117779            0.191529
call_module    layer1_0_relu          0.000107288            0.17447
call_module    layer1_0_relu_1        9.91821e-05            0.161288
call_module    layer1_1_relu_1        9.63211e-05            0.156635
call_module    layer4_0_relu          8.51154e-05            0.138413
call_function  add_5                  8.46386e-05            0.137637
call_module    layer4_1_relu          8.44002e-05            0.13725
call_module    layer2_1_relu          8.36849e-05            0.136087
call_function  add_4                  8.24928e-05            0.134148
call_module    layer2_0_relu          8.10623e-05            0.131822
call_module    layer2_1_relu_1        8.01086e-05            0.130271
call_module    layer2_0_relu_1        7.96318e-05            0.129496
call_module    layer3_0_relu_1        7.9155e-05             0.12872
call_module    layer4_0_relu_1        7.7486e-05             0.126006
call_function  add_7                  7.7486e-05             0.126006
call_module    layer3_1_relu          7.70092e-05            0.125231
call_function  add_6                  7.67708e-05            0.124843
call_module    layer4_1_relu_1        7.67708e-05            0.124843
call_module    layer3_0_relu          7.65324e-05            0.124455
call_module    layer3_1_relu_1        7.10487e-05            0.115538
call_function  flatten                4.3869e-05             0.0713388
placeholder    x                      2.59876e-05            0.0422605
output         output                 1.95503e-05            0.0317923
```
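You can sanity-check how the last column is computed. `summary()` divides each node's mean runtime by the mean total runtime and multiplies by 100. Inverting that formula for any row should therefore give the same total. The short sketch below does this with two rows copied from the table; the `rows` and `implied` names are illustrative.

```python
# Two rows copied from the table above: (Average runtime (s), Pct total runtime)
rows = {
    "maxpool": (0.0058043, 9.43883),
    "conv1": (0.00556087, 9.04297),
}

# summary() computes pct_total = mean_runtime / mean_total_runtime * 100,
# so each row can be inverted to recover mean_total_runtime for the model.
implied = {}
for name, (mean_runtime, pct_total) in rows.items():
    implied[name] = mean_runtime / (pct_total / 100)
    print(f"{name}: implied mean total runtime = {implied[name]:.4f} s")  # 0.0615 s for both
```

Both rows imply a mean total runtime of about 0.0615 s (roughly 61 ms) for the run that produced this table. The value will differ on other machines.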
There are two things worth calling out here:
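- `MaxPool2d` takes up the most time. This is a known issue: github.com/pytorch/pytorch/issues/51393
- `BatchNorm2d` also takes up significant time. We can continue this line of thinking and optimize it in the Conv-BN Fusion with FX tutorial.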
Conclusion
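As we have seen, FX lets us easily capture PyTorch programs (even ones we don't have the source code for!) in a machine-interpretable format. We can then use that format for analysis, such as the performance analysis done here. FX opens up a world of possibilities for working with PyTorch programs.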
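Finally, since FX is still in beta, we would be happy to hear any feedback you have about using it. Please feel free to use the PyTorch Forums (discuss.pytorch.org/) and the issue tracker (github.com/pytorch/pytorch/issues) to share it.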
Total running time of the script: (0 minutes 0.374 seconds)
Frontend APIs
(beta) Channels Last Memory Format in PyTorch
Original: pytorch.org/tutorials/intermediate/memory_format_tutorial.html
Translator: 飞龙
What Is Channels Last
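The channels last memory format is an alternative way of ordering NCHW tensors in memory while preserving the dimension ordering. Channels last tensors are laid out so that channels become the densest dimension, i.e. images are stored pixel by pixel.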
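For example, the classic (contiguous) storage of an NCHW tensor (in our case, two 4x4 images with 3 color channels) keeps each channel as a separate 4x4 plane: all values of channel 0 of the first image, then all values of channel 1, and so on.
The channels last memory format orders the data differently: the 3 channel values of each pixel are stored next to each other, pixel after pixel.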
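The memory order of this 2x3x4x4 example can be spelled out in a small pure-Python sketch; the variable names are illustrative. Each list holds, in memory order, the logical (n, c, h, w) index stored at that position.

```python
from itertools import product

N, C, H, W = 2, 3, 4, 4  # two 4x4 images with 3 color channels

# Contiguous (NCHW): w changes fastest, so each channel is a separate 4x4 plane.
nchw_order = [(n, c, h, w) for n, c, h, w in product(range(N), range(C), range(H), range(W))]

# Channels last (NHWC): c changes fastest, so the 3 channel values of a pixel
# sit next to each other. Entries are still written as logical (n, c, h, w).
nhwc_order = [(n, c, h, w) for n, h, w, c in product(range(N), range(H), range(W), range(C))]

print(nchw_order[:4])  # [(0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 0, 2), (0, 0, 0, 3)]
print(nhwc_order[:4])  # [(0, 0, 0, 0), (0, 1, 0, 0), (0, 2, 0, 0), (0, 0, 0, 1)]
```

Both lists contain the same 96 elements. Only their order in storage differs, which is why the tensor's shape does not change.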
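PyTorch supports memory formats (and provides backward compatibility with existing models, including eager, JIT, and TorchScript) by reusing the existing strides structure. For example, a 10x3x16x16 batch in channels last format has strides equal to (768, 1, 48, 3).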
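These strides follow directly from the memory order. A dimension's stride is the product of the sizes of all dimensions stored inside it. The minimal pure-Python sketch below reproduces the number above and the ones printed in the next section. It also shows the offset computation that strides exist for. The function names are hypothetical, not PyTorch APIs.

```python
from typing import Sequence, Tuple


def contiguous_strides(n: int, c: int, h: int, w: int) -> Tuple[int, int, int, int]:
    # NCHW memory order: w is innermost, then h, then c, then n.
    # (n never affects strides; it is kept for a uniform signature.)
    return (c * h * w, h * w, w, 1)


def channels_last_strides(n: int, c: int, h: int, w: int) -> Tuple[int, int, int, int]:
    # NHWC memory order (c innermost), but strides are still reported in
    # logical (N, C, H, W) order, which is why the shape is unchanged.
    return (h * w * c, 1, w * c, c)


def offset(index: Sequence[int], strides: Sequence[int]) -> int:
    # Position of element `index` in the underlying 1-D storage.
    return sum(i * s for i, s in zip(index, strides))


print(channels_last_strides(10, 3, 16, 16))  # (768, 1, 48, 3)
print(contiguous_strides(10, 3, 32, 32))     # (3072, 1024, 32, 1)
print(channels_last_strides(10, 3, 32, 32))  # (3072, 1, 96, 3)
```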
The channels last memory format is implemented for 4D NCHW tensors only.
Memory Format API
Here is how to convert tensors between the contiguous and channels last memory formats.
Classic PyTorch contiguous tensor
```python
import torch

N, C, H, W = 10, 3, 32, 32
x = torch.empty(N, C, H, W)
print(x.stride())  # Outputs: (3072, 1024, 32, 1)
```
(3072, 1024, 32, 1)
Conversion operator
```python
x = x.to(memory_format=torch.channels_last)
print(x.shape)  # Outputs: (10, 3, 32, 32) as dimensions order preserved
print(x.stride())  # Outputs: (3072, 1, 96, 3)
```
torch.Size([10, 3, 32, 32])
(3072, 1, 96, 3)
Back to contiguous
```python
x = x.to(memory_format=torch.contiguous_format)
print(x.stride())  # Outputs: (3072, 1024, 32, 1)
```
(3072, 1024, 32, 1)
Alternative option
```python
x = x.contiguous(memory_format=torch.channels_last)
print(x.stride())  # Outputs: (3072, 1, 96, 3)
```
(3072, 1, 96, 3)
Format checks
```python
print(x.is_contiguous(memory_format=torch.channels_last))  # Outputs: True
```
True
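There are minor differences between the two APIs `to` and `contiguous`. We suggest sticking with `to` when explicitly converting the memory format of a tensor.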
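In general, the two APIs behave the same. There are special cases, however. For a 4D tensor of size NCHW with either `C==1` or `H==1 && W==1`, only `to` produces proper strides for the channels last memory format.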
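In both cases the memory format of the tensor is ambiguous: for example, a contiguous tensor of size N1HW is both `contiguous` and channels last in memory storage. Such a tensor already counts as `is_contiguous` for the requested memory format, so the `contiguous` call is a no-op and does not update the strides. `to`, by contrast, restrides the tensor with meaningful strides on the size-1 dimensions so that it properly represents the intended memory format.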
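This ambiguity can be illustrated with a simplified pure-Python model. It is not PyTorch's actual `is_contiguous` logic, which handles more corner cases, and all names are hypothetical. In the model, a layout check ignores size-1 dimensions: the only valid index into them is 0, so their stride never contributes to an offset. Under that rule, the N1HW strides of a contiguous allocation (like `special_x` below) pass both checks. A dedicated channels-last layout would still assign different strides.

```python
from typing import Sequence, Tuple

# Memory order of the (N, C, H, W) dims, from innermost (fastest) to outermost.
CONTIGUOUS_ORDER = (3, 2, 1, 0)     # W, H, C, N
CHANNELS_LAST_ORDER = (1, 3, 2, 0)  # C, W, H, N


def dense_strides(shape: Sequence[int], order: Sequence[int]) -> Tuple[int, ...]:
    """Strides of a densely packed tensor whose dims are laid out in `order`."""
    strides = [0] * len(shape)
    step = 1
    for dim in order:
        strides[dim] = step
        step *= shape[dim]
    return tuple(strides)


def looks_like(shape: Sequence[int], strides: Sequence[int], order: Sequence[int]) -> bool:
    """Simplified check: do `strides` describe a dense layout in `order`?

    Size-1 dims are skipped because their stride can never matter.
    """
    step = 1
    for dim in order:
        if shape[dim] == 1:
            continue
        if strides[dim] != step:
            return False
        step *= shape[dim]
    return True


shape = (4, 1, 4, 4)  # N1HW, the same size as special_x
contiguous = dense_strides(shape, CONTIGUOUS_ORDER)
print(contiguous)                                          # (16, 16, 4, 1)
print(looks_like(shape, contiguous, CONTIGUOUS_ORDER))     # True
print(looks_like(shape, contiguous, CHANNELS_LAST_ORDER))  # True: ambiguous
print(dense_strides(shape, CHANNELS_LAST_ORDER))           # (16, 1, 4, 1)
```

In this model, a `contiguous`-style conversion sees that the check already passes and changes nothing. A `to`-style conversion always writes the channels-last strides.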
```python
special_x = torch.empty(4, 1, 4, 4)
print(special_x.is_contiguous(memory_format=torch.channels_last))  # Outputs: True
print(special_x.is_contiguous(memory_format=torch.contiguous_format))  # Outputs: True
```
True
True
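The same applies to the explicit permutation API `permute`. In the special cases where ambiguity can occur, `permute` is not guaranteed to produce strides that properly carry the intended memory format. We suggest using `to` with an explicit memory format to avoid unintended behavior.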
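Note also that in the extreme case where all three non-batch dimensions equal 1 (`C==1 && H==1 && W==1`), the current implementation cannot mark a tensor as having the channels last memory format.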
Create as channels last
```python
x = torch.empty(N, C, H, W, memory_format=torch.channels_last)
print(x.stride())  # Outputs: (3072, 1, 96, 3)
```
(3072, 1, 96, 3)
`clone` preserves the memory format
```python
y = x.clone()
print(y.stride())  # Outputs: (3072, 1, 96, 3)
```
(3072, 1, 96, 3)
`to`, `cuda`, `float`, … preserve the memory format
```python
if torch.cuda.is_available():
    y = x.cuda()
    print(y.stride())  # Outputs: (3072, 1, 96, 3)
```
(3072, 1, 96, 3)
`empty_like` and the other `*_like` operators preserve the memory format
```python
y = torch.empty_like(x)
print(y.stride())  # Outputs: (3072, 1, 96, 3)
```
(3072, 1, 96, 3)
Pointwise operators preserve the memory format
```python
z = x + y
print(z.stride())  # Outputs: (3072, 1, 96, 3)
```
(3072, 1, 96, 3)
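`Conv` and `Batchnorm` modules using the `cudnn` backend support channels last (cuDNN >= 7.6 only). Unlike binary pointwise operators, convolution modules treat channels last as the dominant memory format. If all inputs are in contiguous memory format, the operator produces output in contiguous memory format. Otherwise, the output is produced in channels last memory format.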
```python
# Also require a visible GPU: cudnn.is_available() can be True on a CUDA build
# running on a machine without a GPU, where .cuda() would fail.
if (torch.cuda.is_available()
        and torch.backends.cudnn.is_available()
        and torch.backends.cudnn.version() >= 7603):
    model = torch.nn.Conv2d(8, 4, 3).cuda().half()
    model = model.to(memory_format=torch.channels_last)  # Module parameters need to be channels last

    input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, requires_grad=True)
    input = input.to(device="cuda", memory_format=torch.channels_last, dtype=torch.float16)

    out = model(input)
    print(out.is_contiguous(memory_format=torch.channels_last))  # Outputs: True
```
True
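When an input tensor reaches an operator without channels last support, the kernel should automatically apply a permutation to restore contiguity on the input tensor. This introduces overhead and stops the propagation of the channels last memory format. It does, however, guarantee correct output.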
Performance Gains
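Channels last memory format optimizations are available on both GPU and CPU. On GPU, the most significant performance gains are observed on NVIDIA hardware with Tensor Cores support running at reduced precision (`torch.float16`). Using 'AMP (Automatic Mixed Precision)' training scripts, we achieved over 22% performance gains with channels last compared to the contiguous format. The scripts use the AMP implementation supplied by NVIDIA: github.com/NVIDIA/apex.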
```shell
python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 ./data
```
```
# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
# CUDNN VERSION: 7603
# => creating model 'resnet50'
# Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
# Defaults for this optimization level are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
# Epoch: [0][10/125] Time 0.866 (0.866) Speed 230.949 (230.949) Loss 0.6735125184 (0.6735) Prec@1 61.000 (61.000) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.259 (0.562) Speed 773.481 (355.693) Loss 0.6968704462 (0.6852) Prec@1 55.000 (58.000) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.258 (0.461) Speed 775.089 (433.965) Loss 0.7877287269 (0.7194) Prec@1 51.500 (55.833) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.259 (0.410) Speed 771.710 (487.281) Loss 0.8285319805 (0.7467) Prec@1 48.500 (54.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.260 (0.380) Speed 770.090 (525.908) Loss 0.7370464802 (0.7447) Prec@1 56.500 (54.500) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.258 (0.360) Speed 775.623 (555.728) Loss 0.7592862844 (0.7472) Prec@1 51.000 (53.917) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.258 (0.345) Speed 774.746 (579.115) Loss 1.9698858261 (0.9218) Prec@1 49.500 (53.286) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.260 (0.335) Speed 770.324 (597.659) Loss 2.2505953312 (1.0879) Prec@1 50.500 (52.938) Prec@5 100.000 (100.000)
```
Passing `--channels-last true` runs the model in channels last format, with an observed 22% performance gain.
```shell
python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 --channels-last true ./data
```
```
# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
#
# CUDNN VERSION: 7603
#
# => creating model 'resnet50'
# Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
#
# Defaults for this optimization level are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
#
# Epoch: [0][10/125] Time 0.767 (0.767) Speed 260.785 (260.785) Loss 0.7579724789 (0.7580) Prec@1 53.500 (53.500) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.198 (0.482) Speed 1012.135 (414.716) Loss 0.7007197738 (0.7293) Prec@1 49.000 (51.250) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.198 (0.387) Speed 1010.977 (516.198) Loss 0.7113101482 (0.7233) Prec@1 55.500 (52.667) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.197 (0.340) Speed 1013.023 (588.333) Loss 0.8943189979 (0.7661) Prec@1 54.000 (53.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.198 (0.312) Speed 1010.541 (641.977) Loss 1.7113249302 (0.9551) Prec@1 51.000 (52.600) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.198 (0.293) Speed 1011.163 (683.574) Loss 5.8537774086 (1.7716) Prec@1 50.500 (52.250) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.198 (0.279) Speed 1011.453 (716.767) Loss 5.7595844269 (2.3413) Prec@1 46.500 (51.429) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.198 (0.269) Speed 1011.827 (743.883) Loss 2.8196096420 (2.4011) Prec@1 47.500 (50.938) Prec@5 100.000 (100.000)
```
The following models fully support channels last and show 8%-35% performance gains on Volta devices: alexnet, mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3, mobilenet_v2, resnet101, resnet152, resnet18, resnet34, resnet50, resnext50_32x4d, shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5, shufflenet_v2_x2_0, squeezenet1_0, squeezenet1_1, vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn, wide_resnet101_2, wide_resnet50_2
The following models fully support channels last and show 26%-76% performance gains on Intel® Xeon® Ice Lake (or newer) CPUs: alexnet, densenet121, densenet161, densenet169, googlenet, inception_v3, mnasnet0_5, mnasnet1_0, resnet101, resnet152, resnet18, resnet34, resnet50, resnext101_32x8d, resnext50_32x4d, shufflenet_v2_x0_5, shufflenet_v2_x1_0, squeezenet1_0, squeezenet1_1, vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn, wide_resnet101_2, wide_resnet50_2
Converting Existing Models
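Channels last support is not limited to the existing models listed above. Any model can be converted to channels last, and the format will propagate through the graph once the input (or certain weights) is formatted correctly.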
```python
# Need to be done once, after model initialization (or load)
model = model.to(memory_format=torch.channels_last)  # Replace with your model

# Need to be done for every input
input = input.to(memory_format=torch.channels_last)  # Replace with your input
output = model(input)
```
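However, not all operators have been converted to support channels last; such operators usually return contiguous output instead. In the example above, layers that do not support channels last stop the propagation of the memory format. Even so, because the model itself was converted to channels last, every convolution layer has its 4D weight in channels last memory format. Each convolution therefore restores the channels last memory format and benefits from the faster kernels.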
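This propagation behavior can be summarized in a toy model, which is a simplification under stated assumptions. Every layer is tagged `conv`, `pointwise` or `unsupported`; these are hypothetical labels. A conv produces channels-last output unless both its input and its weight are contiguous. Pointwise ops keep their input's format. An unsupported op returns contiguous output.

```python
from typing import List

CL, CONTIG = "channels_last", "contiguous"


def propagate(layers: List[str], input_format: str, conv_weights_cl: bool = True) -> List[str]:
    """Return the memory format of each layer's output under the toy rules."""
    fmt = input_format
    trace: List[str] = []
    for kind in layers:
        if kind == "conv":
            # Channels last dominates: the output is contiguous only if the
            # input and the weight are both contiguous.
            fmt = CL if (fmt == CL or conv_weights_cl) else CONTIG
        elif kind == "unsupported":
            # No channels-last kernel: input is permuted, output is contiguous.
            fmt = CONTIG
        elif kind != "pointwise":
            raise ValueError(f"unknown layer kind: {kind}")
        # "pointwise" keeps the format it received.
        trace.append(fmt)
    return trace


print(propagate(["conv", "unsupported", "pointwise", "conv"], CL))
# ['channels_last', 'contiguous', 'contiguous', 'channels_last']
```

The second conv recovers channels last only because its weight was converted. Every unsupported layer still pays for its permutation.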
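Operators that do not support channels last still add permutation overhead. If you want to improve the performance of a converted model, you can optionally investigate and identify the operators in your model that lack channels last support.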
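To do so, either check the operators you use against the list of supported operators (github.com/pytorch/pytorch/wiki/Operators-with-Channels-Last-support), or add memory format checks to eager execution mode and run your model.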
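After you run the code below, any operator that receives a channels last input but returns a 4D output that is not channels last will raise an exception.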
```python
import torch


# True if any tensor in `args` (searching nested lists/tuples) is channels last
# but not also contiguous, i.e. unambiguously channels last.
def contains_cl(args):
    for t in args:
        if isinstance(t, torch.Tensor):
            if t.is_contiguous(memory_format=torch.channels_last) and not t.is_contiguous():
                return True
        elif isinstance(t, list) or isinstance(t, tuple):
            if contains_cl(list(t)):
                return True
    return False


# Print stride, shape, device and dtype of every tensor in `args`.
def print_inputs(args, indent=""):
    for t in args:
        if isinstance(t, torch.Tensor):
            print(indent, t.stride(), t.shape, t.device, t.dtype)
        elif isinstance(t, list) or isinstance(t, tuple):
            print(indent, type(t))
            print_inputs(list(t), indent=indent + "    ")
        else:
            print(indent, t)


# Wrap `fn` so that it raises if it gets a channels last input but
# returns a 4D tensor that is not channels last.
def check_wrapper(fn):
    name = fn.__name__

    def check_cl(*args, **kwargs):
        was_cl = contains_cl(args)
        try:
            result = fn(*args, **kwargs)
        except Exception as e:
            print("`{}` inputs are:".format(name))
            print_inputs(args)
            print("-------------------")
            raise e
        failed = False
        if was_cl:
            if isinstance(result, torch.Tensor):
                if result.dim() == 4 and not result.is_contiguous(memory_format=torch.channels_last):
                    print(
                        "`{}` got channels_last input, but output is not channels_last:".format(name),
                        result.shape,
                        result.stride(),
                        result.device,
                        result.dtype,
                    )
                    failed = True
        if failed:
            print("`{}` inputs are:".format(name))
            print_inputs(args)
            raise Exception("Operator `{}` lost channels_last property".format(name))
        return result

    return check_cl


# Original attributes, saved so they can be restored afterwards.
old_attrs = dict()


# Replace every public callable attribute of `m` with a checked version.
def attribute(m):
    old_attrs[m] = dict()
    for i in dir(m):
        e = getattr(m, i)
        exclude_functions = ["is_cuda", "has_names", "numel", "stride", "Tensor", "is_contiguous", "__class__"]
        if i not in exclude_functions and not i.startswith("_") and "__call__" in dir(e):
            try:
                old_attrs[m][i] = e
                setattr(m, i, check_wrapper(e))
            except Exception as e:
                print(i)
                print(e)


attribute(torch.Tensor)
attribute(torch.nn.functional)
attribute(torch)
```
If you find an operator that doesn't support channels last tensors and you want to contribute, feel free to use the following developer guide: github.com/pytorch/pytorch/wiki/Writing-memory-format-aware-operators.
The code below restores the original attributes of torch.
```python
for (m, attrs) in old_attrs.items():
    for (k, v) in attrs.items():
        setattr(m, k, v)
```
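One possible refinement of this patch-and-restore pair is to wrap it in a context manager. The original attributes are then restored even if running the model raises. This is a design suggestion, not part of the tutorial. The pure-Python sketch below shows the pattern on a hypothetical namespace `ops` with a logging wrapper; `instrumented` and `log_calls` are illustrative names. Applying it to PyTorch would mean passing `torch.Tensor`, `torch.nn.functional` or `torch` as the namespace, `check_wrapper` as the wrapper, and the same exclusion list. That combination is untested here.

```python
import contextlib
import types
from typing import Any, Callable, Dict, Iterable, Iterator


@contextlib.contextmanager
def instrumented(namespace: Any, wrap: Callable[[Callable], Callable],
                 exclude: Iterable[str] = ()) -> Iterator[None]:
    """Temporarily replace public callables of `namespace` with wrap(original)."""
    skip = set(exclude)
    saved: Dict[str, Any] = {}
    for name in dir(namespace):
        if name.startswith("_") or name in skip:
            continue
        attr = getattr(namespace, name)
        if not callable(attr):
            continue
        try:
            setattr(namespace, name, wrap(attr))
        except (AttributeError, TypeError):
            continue  # read-only attribute: leave it untouched
        saved[name] = attr
    try:
        yield
    finally:
        # Runs on normal exit and also when the body raises.
        for name, attr in saved.items():
            setattr(namespace, name, attr)


calls = []


def log_calls(fn: Callable) -> Callable:
    def wrapper(*args, **kwargs):
        calls.append(fn.__name__)
        return fn(*args, **kwargs)
    return wrapper


def double(x):
    return 2 * x


ops = types.SimpleNamespace(double=double)  # hypothetical namespace to patch
with instrumented(ops, log_calls):
    ops.double(21)
print(calls)                 # ['double']
print(ops.double is double)  # True: the original is back
```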
Work to Do
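There are still many things to do, such as:
- Resolving the ambiguity of N1HW and NC11 tensors;
- Testing distributed training support;
- Improving operator coverage.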
If you have feedback and/or suggestions for improvement, please let us know by creating an issue.
Total running time of the script: (0 minutes 0.038 seconds)
Forward-mode Automatic Differentiation (Beta)
Original: pytorch.org/tutorials/intermediate/forward_ad_usage.html
Translator: 飞龙
This tutorial demonstrates how to use forward-mode AD to compute directional derivatives (or, equivalently, Jacobian-vector products).
The tutorial below uses some APIs that are only available in versions >= 1.11 (or nightly builds).
Also note that forward-mode AD is currently in beta. The API is subject to change, and operator coverage is still incomplete.
Basic Usage
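Unlike reverse-mode AD, forward-mode AD computes gradients eagerly alongside the forward pass. To compute a directional derivative with forward-mode AD, we run the forward pass as usual, but first associate the input with another tensor that represents the direction of the directional derivative (equivalently, the `v` in a Jacobian-vector product). The input is called the "primal" and the direction tensor is called the "tangent". The new tensor object that results from associating them is called a "dual tensor", because of its connection to dual numbers[0].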
During the forward pass, if any input tensor is a dual tensor, extra computation is performed to propagate this "sensitivity" of the function.
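Dual tensors take their name from dual numbers: values a + b·ε with ε² = 0, where b plays the role of the tangent. The minimal pure-Python sketch below uses scalars only. Its `Dual` class is a hypothetical illustration, not part of `torch.autograd.forward_ad`. It shows how ordinary arithmetic rules carry the tangent along, which is the "extra computation" mentioned above.

```python
from dataclasses import dataclass


@dataclass
class Dual:
    """Toy dual number: primal + tangent * eps, with eps ** 2 == 0."""
    primal: float
    tangent: float = 0.0

    def __add__(self, other):
        other = _lift(other)
        # Sum rule: tangents add.
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = _lift(other)
        # Product rule: (a + a'eps)(b + b'eps) = ab + (a'b + ab')eps
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)

    __rmul__ = __mul__

    def __pow__(self, k: int):
        # Power rule for a constant exponent: d(x ** k) = k * x ** (k - 1) * dx
        return Dual(self.primal ** k, k * self.primal ** (k - 1) * self.tangent)


def _lift(v):
    # Plain numbers act like duals with a zero tangent.
    return v if isinstance(v, Dual) else Dual(float(v), 0.0)


def fn(x, y):
    return x ** 2 + y ** 2


print(fn(Dual(3.0, 1.0), 4.0))  # Dual(primal=25.0, tangent=6.0)
```

With `y` treated as a constant (zero tangent), the output tangent of `fn` is `2 * x * t`. In the tensor example that follows, `jvp` should therefore equal `2 * primal * tangent` elementwise.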
```python
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)

def fn(x, y):
    return x ** 2 + y ** 2

# All forward AD computation must be performed in the context of
# a ``dual_level`` context. All dual tensors created in such a context
# will have their tangents destroyed upon exit. This is to ensure that
# if the output or intermediate results of this computation are reused
# in a future forward AD computation, their tangents (which are associated
# with this computation) won't be confused with tangents from the later
# computation.
with fwAD.dual_level():
    # To create a dual tensor we associate a tensor, which we call the
    # primal, with another tensor of the same size, which we call the tangent.
    # If the layout of the tangent is different from that of the primal,
    # the values of the tangent are copied into a new tensor with the same
    # metadata as the primal. Otherwise, the tangent itself is used as-is.
    #
    # It is also important to note that the dual tensor created by
    # ``make_dual`` is a view of the primal.
    dual_input = fwAD.make_dual(primal, tangent)
    assert fwAD.unpack_dual(dual_input).tangent is tangent

    # To demonstrate the case where the copy of the tangent happens,
    # we pass in a tangent with a layout different from that of the primal
    dual_input_alt = fwAD.make_dual(primal, tangent.T)
    assert fwAD.unpack_dual(dual_input_alt).tangent is not tangent

    # Tensors that do not have an associated tangent are automatically
    # considered to have a zero-filled tangent of the same shape.
    plain_tensor = torch.randn(10, 10)
    dual_output = fn(dual_input, plain_tensor)

    # Unpacking the dual returns a ``namedtuple`` with ``primal`` and ``tangent``
    # as attributes
    jvp = fwAD.unpack_dual(dual_output).tangent

# Outside the ``dual_level`` context the tangents have been destroyed,
# so the output no longer carries one.
assert fwAD.unpack_dual(dual_output).tangent is None
```
PyTorch 2.2 Official Chinese Tutorials (10) (3): https://developer.aliyun.com/article/1482541