PyTorch 2.2 Official Tutorials in Chinese (10)(2)

Overview: PyTorch 2.2 Official Tutorials in Chinese (10)

Previous part: PyTorch 2.2 Official Tutorials in Chinese (10)(1): https://developer.aliyun.com/article/1482537

Creating a Profiling Interpreter

Next, we are going to create a class that inherits from `torch.fx.Interpreter`. While the `GraphModule` produced by `symbolic_trace` compiles Python code that is run when you call the `GraphModule`, an alternative way of running a `GraphModule` is to execute each `Node` in the `Graph` one by one. That is the functionality `Interpreter` provides: it interprets the graph node by node.
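
Before we customize it, here is a minimal sketch (not part of the original tutorial) of how the stock `Interpreter` replays a traced graph node by node; `simple_mod` is just a hypothetical stand-in module:

import torch
import torch.fx
from torch.fx import Interpreter

# Any symbolically traceable module works here; this one is purely illustrative.
simple_mod = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
traced = torch.fx.symbolic_trace(simple_mod)

inp = torch.randn(2, 4)
# Calling the GraphModule runs its generated Python code...
out_compiled = traced(inp)
# ...while Interpreter.run() walks the Graph one Node at a time.
out_interpreted = Interpreter(traced).run(inp)
assert torch.allclose(out_compiled, out_interpreted)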

By subclassing `Interpreter`, we can override various pieces of functionality and install the profiling behavior we want. The goal is to have an object that we can pass a model to, invoke the model one or more times, and then get statistics about how long the model, and each of its parts, took during those runs.

Let's define our `ProfilingInterpreter` class:

# Imports required by this class (carried over from part (1) of this tutorial):
import statistics, tabulate, time
from typing import Any, Dict, List

import torch
import torch.fx
from torch.fx import Interpreter

class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)
        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}
    ######################################################################
    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
    # method is the top-level entry point for execution of the model. We will
    # want to intercept this so that we can record the total runtime of the
    # model.
    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ``ProfilingInterpreter``
        self.total_runtime_sec.append(t_end - t_start)
        return return_val
    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
    # time it executes a single node. We will intercept this so that we
    # can measure and record the time taken for each individual call in
    # the model.
    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val
    ######################################################################
    # Finally, we are going to define a method (one which doesn't override
    # any ``Interpreter`` method) that provides us a nice, organized view of
    # the data we have collected.
    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)
        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime.
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])
        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)
        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers) 

Note

We use Python's `time.time` function to obtain wall-clock timestamps and compare them. This is not the most accurate way to measure performance and only gives us a first-order approximation. We use this simple technique for demonstration purposes only.
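
If you want slightly more reliable timings, a common alternative (a sketch, not part of the original tutorial) is `time.perf_counter`, which is monotonic and has higher resolution than `time.time`:

import time

t_start = time.perf_counter()
_ = sum(i * i for i in range(100_000))  # placeholder workload, purely illustrative
elapsed = time.perf_counter() - t_start
print(f"elapsed: {elapsed:.6f} s")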

Investigating the Performance of ResNet18

We can now use `ProfilingInterpreter` to inspect the performance characteristics of our ResNet18 model:

interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True)) 
Op type        Op                       Average runtime (s)    Pct total runtime
-------------  ---------------------  ---------------------  -------------------
call_module    maxpool                          0.0058043              9.43883
call_module    conv1                            0.00556087             9.04297
call_module    layer4_0_conv2                   0.00342155             5.56404
call_module    layer4_1_conv2                   0.00325394             5.29148
call_module    layer4_1_conv1                   0.00316119             5.14066
call_module    layer1_0_conv2                   0.00267935             4.3571
call_module    layer1_1_conv1                   0.00267816             4.35516
call_module    layer3_0_conv2                   0.00267792             4.35477
call_module    layer3_1_conv1                   0.00261283             4.24893
call_module    layer3_1_conv2                   0.00259137             4.21403
call_module    layer1_0_conv1                   0.00256515             4.17138
call_module    layer2_1_conv1                   0.00249219             4.05274
call_module    layer2_1_conv2                   0.0024581              3.9973
call_module    layer2_0_conv2                   0.00242114             3.93721
call_module    layer1_1_conv2                   0.00241613             3.92906
call_module    layer4_0_conv1                   0.00203657             3.31183
call_module    layer3_0_conv1                   0.00165725             2.69498
call_module    layer2_0_conv1                   0.00164604             2.67676
call_module    bn1                              0.00133991             2.17894
call_module    layer2_0_downsample_0            0.000616312            1.00223
call_module    layer3_0_downsample_0            0.000507832            0.825825
call_module    layer4_0_downsample_0            0.000471115            0.766117
call_function  add                              0.00034976             0.568772
call_module    relu                             0.000216722            0.352429
call_function  add_1                            0.000201702            0.328004
call_module    fc                               0.000183105            0.297762
call_module    layer1_0_bn1                     0.000178337            0.290008
call_module    layer1_0_bn2                     0.000164032            0.266745
call_module    layer1_1_bn1                     0.000163794            0.266358
call_module    layer1_1_bn2                     0.000160933            0.261705
call_module    avgpool                          0.000149012            0.242319
call_module    layer2_1_bn2                     0.000141621            0.2303
call_module    layer2_0_downsample_1            0.000141382            0.229913
call_module    layer4_0_bn2                     0.000140429            0.228362
call_module    layer2_0_bn1                     0.000137806            0.224097
call_module    layer4_1_bn2                     0.000136852            0.222546
call_module    layer2_1_bn1                     0.000136137            0.221383
call_module    layer2_0_bn2                     0.000132799            0.215955
call_module    layer1_1_relu                    0.000128984            0.209752
call_function  add_2                            0.000127316            0.207038
call_module    layer3_1_bn1                     0.000127316            0.207038
call_module    layer3_0_downsample_1            0.0001266              0.205875
call_module    layer3_0_bn1                     0.000126362            0.205487
call_module    layer3_0_bn2                     0.000125647            0.204324
call_function  add_3                            0.000124454            0.202385
call_module    layer3_1_bn2                     0.000123978            0.20161
call_module    layer4_1_bn1                     0.000119686            0.194631
call_module    layer4_0_downsample_1            0.000118017            0.191917
call_module    layer4_0_bn1                     0.000117779            0.191529
call_module    layer1_0_relu                    0.000107288            0.17447
call_module    layer1_0_relu_1                  9.91821e-05            0.161288
call_module    layer1_1_relu_1                  9.63211e-05            0.156635
call_module    layer4_0_relu                    8.51154e-05            0.138413
call_function  add_5                            8.46386e-05            0.137637
call_module    layer4_1_relu                    8.44002e-05            0.13725
call_module    layer2_1_relu                    8.36849e-05            0.136087
call_function  add_4                            8.24928e-05            0.134148
call_module    layer2_0_relu                    8.10623e-05            0.131822
call_module    layer2_1_relu_1                  8.01086e-05            0.130271
call_module    layer2_0_relu_1                  7.96318e-05            0.129496
call_module    layer3_0_relu_1                  7.9155e-05             0.12872
call_module    layer4_0_relu_1                  7.7486e-05             0.126006
call_function  add_7                            7.7486e-05             0.126006
call_module    layer3_1_relu                    7.70092e-05            0.125231
call_function  add_6                            7.67708e-05            0.124843
call_module    layer4_1_relu_1                  7.67708e-05            0.124843
call_module    layer3_0_relu                    7.65324e-05            0.124455
call_module    layer3_1_relu_1                  7.10487e-05            0.115538
call_function  flatten                          4.3869e-05             0.0713388
placeholder    x                                2.59876e-05            0.0422605
output         output                           1.95503e-05            0.0317923 

There are two things we should call out here:

  • MaxPool2d takes up the most time. This is a known issue: github.com/pytorch/pytorch/issues/51393
  • BatchNorm2d also takes up significant time. We can continue this line of thinking and optimize it in the Conv-BN Fusion with FX tutorial.

Conclusion

As we can see, using FX we can easily capture PyTorch programs (even ones we don't have the source code for!) in a machine-interpretable format and use it for analysis, such as the performance analysis we have done here. FX opens up an exciting world of possibilities for working with PyTorch programs.

Finally, since FX is still in beta, we would be happy to hear any feedback you have about using it. Please feel free to use the PyTorch Forum (discuss.pytorch.org/) and the issue tracker (github.com/pytorch/pytorch/issues) to provide any feedback you may have.

Total running time of the script: (0 minutes 0.374 seconds)

Download Python source code: fx_profiling_tutorial.py

Download Jupyter notebook: fx_profiling_tutorial.ipynb

Gallery generated by Sphinx-Gallery

Frontend APIs

(beta) Channels Last Memory Format in PyTorch

Original: pytorch.org/tutorials/intermediate/memory_format_tutorial.html

Translator: 飞龙

License: CC BY-NC-SA 4.0

Note

Click here to download the full example code

Author: Vitaly Fedyunin

What is Channels Last

The channels last memory format is an alternative way of ordering NCHW tensors in memory while preserving the dimension ordering. Channels last tensors are ordered in such a way that channels become the densest dimension (i.e., images are stored pixel by pixel).

For example, the classic (contiguous) storage of an NCHW tensor (in our case, two 4x4 images with 3 color channels) looks like this:

The channels last memory format orders the data differently:

PyTorch supports memory formats by utilizing the existing strides structure (and provides backward compatibility with existing models, including eager, JIT, and TorchScript). For example, a 10x3x16x16 batch in channels last format will have strides equal to (768, 1, 48, 3).

The channels last memory format is implemented for 4D NCHW tensors only.
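
As a quick sanity check of the stride figures quoted above, the following sketch (not part of the original text) allocates the same 10x3x16x16 batch in channels last format and prints its strides:

import torch

x = torch.empty(10, 3, 16, 16, memory_format=torch.channels_last)
print(x.stride())  # expected: (768, 1, 48, 3)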

Memory Format API

Here is how to convert tensors between the contiguous and channels last memory formats.

Classic PyTorch contiguous tensor

import torch
N, C, H, W = 10, 3, 32, 32
x = torch.empty(N, C, H, W)
print(x.stride())  # Outputs: (3072, 1024, 32, 1) 
(3072, 1024, 32, 1) 

Conversion operator

x = x.to(memory_format=torch.channels_last)
print(x.shape)  # Outputs: (10, 3, 32, 32) as dimensions order preserved
print(x.stride())  # Outputs: (3072, 1, 96, 3) 
torch.Size([10, 3, 32, 32])
(3072, 1, 96, 3) 

Back to contiguous

x = x.to(memory_format=torch.contiguous_format)
print(x.stride())  # Outputs: (3072, 1024, 32, 1) 
(3072, 1024, 32, 1) 

Alternative option

x = x.contiguous(memory_format=torch.channels_last)
print(x.stride())  # Outputs: (3072, 1, 96, 3) 
(3072, 1, 96, 3) 

Format checks

print(x.is_contiguous(memory_format=torch.channels_last))  # Outputs: True 
True 

There are minor differences between the `to` and `contiguous` APIs. We suggest sticking with `to` when explicitly converting the memory format of a tensor.

For the general case the two APIs behave the same. However, in special cases, for a 4D tensor of size NCHW, when either C==1 or both H==1 && W==1, only `to` generates a proper stride to represent the channels last memory format.

This is because in either of the two cases above, the memory format of the tensor is ambiguous: for example, a contiguous tensor of size N1HW is both contiguous and channels last in memory storage. It is therefore already considered `is_contiguous` for the given memory format, so the `contiguous` call becomes a no-op and does not update the strides. In contrast, `to` restrides the tensor with meaningful strides on the dimensions whose sizes are 1 in order to properly represent the intended memory format.

special_x = torch.empty(4, 1, 4, 4)
print(special_x.is_contiguous(memory_format=torch.channels_last))  # Outputs: True
print(special_x.is_contiguous(memory_format=torch.contiguous_format))  # Outputs: True 
True
True 

The same applies to the explicit permutation API, `permute`. In the special cases where ambiguity can occur, `permute` does not guarantee to produce strides that properly carry the intended memory format. We suggest using `to` with an explicit memory format to avoid unintended behavior.

As a side note, in the extreme case where all three non-batch dimensions are equal to 1 (C==1 && H==1 && W==1), the current implementation cannot mark a tensor as being in the channels last memory format.
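
Below is a minimal sketch (not from the original tutorial) of the `to` vs. `contiguous` difference described above, using the same ambiguous N1HW shape; the stride values in the comments are what you would typically expect:

import torch

special_x = torch.empty(4, 1, 4, 4)
print(special_x.stride())  # contiguous strides, e.g. (16, 16, 4, 1)

# ``contiguous`` already considers the tensor channels last, so this is a no-op:
noop = special_x.contiguous(memory_format=torch.channels_last)
print(noop.stride())       # unchanged, e.g. (16, 16, 4, 1)

# ``to`` restrides the size-1 channel dimension to express channels last properly:
restrided = special_x.to(memory_format=torch.channels_last)
print(restrided.stride())  # e.g. (16, 1, 4, 1)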

Create as channels last

x = torch.empty(N, C, H, W, memory_format=torch.channels_last)
print(x.stride())  # Outputs: (3072, 1, 96, 3) 
(3072, 1, 96, 3) 

`clone` preserves memory format

y = x.clone()
print(y.stride())  # Outputs: (3072, 1, 96, 3) 
(3072, 1, 96, 3) 

`to`, `cuda`, `float` … preserve memory format

if torch.cuda.is_available():
    y = x.cuda()
    print(y.stride())  # Outputs: (3072, 1, 96, 3) 
(3072, 1, 96, 3) 

`empty_like`, `*_like` operators preserve memory format

y = torch.empty_like(x)
print(y.stride())  # Outputs: (3072, 1, 96, 3) 
(3072, 1, 96, 3) 

Pointwise operators preserve memory format

z = x + y
print(z.stride())  # Outputs: (3072, 1, 96, 3) 
(3072, 1, 96, 3) 

`Conv` and `Batchnorm` modules using the cudnn backend support channels last (only works for cuDNN >= 7.6). Convolution modules, unlike binary pointwise operators, have channels last as the dominating memory format. If all inputs are in contiguous memory format, the operator produces output in contiguous memory format. Otherwise, the output will be in channels last memory format.

if torch.backends.cudnn.is_available() and torch.backends.cudnn.version() >= 7603:
    model = torch.nn.Conv2d(8, 4, 3).cuda().half()
    model = model.to(memory_format=torch.channels_last)  # Module parameters need to be channels last
    input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, requires_grad=True)
    input = input.to(device="cuda", memory_format=torch.channels_last, dtype=torch.float16)
    out = model(input)
    print(out.is_contiguous(memory_format=torch.channels_last))  # Outputs: True 
True 

When an input tensor reaches an operator without channels last support, a permutation is automatically applied in the kernel to restore contiguity of the input tensor. This introduces overhead and stops the channels last memory format propagation. Nevertheless, it guarantees correct output.

Performance Gains

Channels last memory format optimizations are available on both GPU and CPU. On GPU, the most significant performance gains are observed on NVIDIA hardware with Tensor Cores support, running at reduced precision (torch.float16). We were able to achieve over 22% performance gains with channels last while utilizing 'AMP (Automated Mixed Precision)' training scripts, using the AMP supplied by NVIDIA: github.com/NVIDIA/apex

python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 ./data

# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
# CUDNN VERSION: 7603
# => creating model 'resnet50'
# Selected optimization level O2:  FP16 training with FP32 batchnorm and FP32 master weights.
# Defaults for this optimization level are:
# enabled                : True
# opt_level              : O2
# cast_model_type        : torch.float16
# patch_torch_functions  : False
# keep_batchnorm_fp32    : True
# master_weights         : True
# loss_scale             : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled                : True
# opt_level              : O2
# cast_model_type        : torch.float16
# patch_torch_functions  : False
# keep_batchnorm_fp32    : True
# master_weights         : True
# loss_scale             : dynamic
# Epoch: [0][10/125] Time 0.866 (0.866) Speed 230.949 (230.949) Loss 0.6735125184 (0.6735) Prec@1 61.000 (61.000) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.259 (0.562) Speed 773.481 (355.693) Loss 0.6968704462 (0.6852) Prec@1 55.000 (58.000) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.258 (0.461) Speed 775.089 (433.965) Loss 0.7877287269 (0.7194) Prec@1 51.500 (55.833) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.259 (0.410) Speed 771.710 (487.281) Loss 0.8285319805 (0.7467) Prec@1 48.500 (54.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.260 (0.380) Speed 770.090 (525.908) Loss 0.7370464802 (0.7447) Prec@1 56.500 (54.500) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.258 (0.360) Speed 775.623 (555.728) Loss 0.7592862844 (0.7472) Prec@1 51.000 (53.917) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.258 (0.345) Speed 774.746 (579.115) Loss 1.9698858261 (0.9218) Prec@1 49.500 (53.286) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.260 (0.335) Speed 770.324 (597.659) Loss 2.2505953312 (1.0879) Prec@1 50.500 (52.938) Prec@5 100.000 (100.000) 

Passing --channels-last true allows running a model in the channels last format, with an observed 22% performance gain.

python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 --channels-last true ./data

# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
#
# CUDNN VERSION: 7603
#
# => creating model 'resnet50'
# Selected optimization level O2:  FP16 training with FP32 batchnorm and FP32 master weights.
#
# Defaults for this optimization level are:
# enabled                : True
# opt_level              : O2
# cast_model_type        : torch.float16
# patch_torch_functions  : False
# keep_batchnorm_fp32    : True
# master_weights         : True
# loss_scale             : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled                : True
# opt_level              : O2
# cast_model_type        : torch.float16
# patch_torch_functions  : False
# keep_batchnorm_fp32    : True
# master_weights         : True
# loss_scale             : dynamic
#
# Epoch: [0][10/125] Time 0.767 (0.767) Speed 260.785 (260.785) Loss 0.7579724789 (0.7580) Prec@1 53.500 (53.500) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.198 (0.482) Speed 1012.135 (414.716) Loss 0.7007197738 (0.7293) Prec@1 49.000 (51.250) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.198 (0.387) Speed 1010.977 (516.198) Loss 0.7113101482 (0.7233) Prec@1 55.500 (52.667) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.197 (0.340) Speed 1013.023 (588.333) Loss 0.8943189979 (0.7661) Prec@1 54.000 (53.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.198 (0.312) Speed 1010.541 (641.977) Loss 1.7113249302 (0.9551) Prec@1 51.000 (52.600) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.198 (0.293) Speed 1011.163 (683.574) Loss 5.8537774086 (1.7716) Prec@1 50.500 (52.250) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.198 (0.279) Speed 1011.453 (716.767) Loss 5.7595844269 (2.3413) Prec@1 46.500 (51.429) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.198 (0.269) Speed 1011.827 (743.883) Loss 2.8196096420 (2.4011) Prec@1 47.500 (50.938) Prec@5 100.000 (100.000) 

The following list of models has full support for channels last and shows 8%-35% performance gains on Volta devices: alexnet, mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3, mobilenet_v2, resnet101, resnet152, resnet18, resnet34, resnet50, resnext50_32x4d, shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5, shufflenet_v2_x2_0, squeezenet1_0, squeezenet1_1, vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn, wide_resnet101_2, wide_resnet50_2

The following list of models has full support for channels last and shows 26%-76% performance gains on Intel® Xeon® Ice Lake (or newer) CPUs: alexnet, densenet121, densenet161, densenet169, googlenet, inception_v3, mnasnet0_5, mnasnet1_0, resnet101, resnet152, resnet18, resnet34, resnet50, resnext101_32x8d, resnext50_32x4d, shufflenet_v2_x0_5, shufflenet_v2_x1_0, squeezenet1_0, squeezenet1_1, vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn, wide_resnet101_2, wide_resnet50_2

Converting Existing Models

Channels last support is not limited to existing models: any model can be converted to channels last and will propagate the format through the graph as soon as the input (or certain weights) is formatted correctly.

# Need to be done once, after model initialization (or load)
model = model.to(memory_format=torch.channels_last)  # Replace with your model
# Need to be done for every input
input = input.to(memory_format=torch.channels_last)  # Replace with your input
output = model(input) 

However, not all operators have been fully converted to support channels last (they usually return contiguous output instead). In the example posted above, layers that do not support channels last will stop the memory format propagation. Nevertheless, since we converted the model to the channels last format, every convolution layer, whose 4-dimensional weight is in channels last memory format, will restore the channels last memory format and benefit from faster kernels.

But operators that do not support channels last do introduce overhead through permutation. Optionally, you can investigate and identify the operators in your model that do not support channels last, if you want to improve the performance of the converted model.

That means you need to verify the list of used operators against the supported operators list at github.com/pytorch/pytorch/wiki/Operators-with-Channels-Last-support, or introduce memory format checks into eager execution mode and run your model.

After running the code below, operators will raise an exception if the output of an operator does not match the memory format of its input.

def contains_cl(args):
    for t in args:
        if isinstance(t, torch.Tensor):
            if t.is_contiguous(memory_format=torch.channels_last) and not t.is_contiguous():
                return True
        elif isinstance(t, list) or isinstance(t, tuple):
            if contains_cl(list(t)):
                return True
    return False
def print_inputs(args, indent=""):
    for t in args:
        if isinstance(t, torch.Tensor):
            print(indent, t.stride(), t.shape, t.device, t.dtype)
        elif isinstance(t, list) or isinstance(t, tuple):
            print(indent, type(t))
            print_inputs(list(t), indent=indent + "    ")
        else:
            print(indent, t)
def check_wrapper(fn):
    name = fn.__name__
    def check_cl(*args, **kwargs):
        was_cl = contains_cl(args)
        try:
            result = fn(*args, **kwargs)
        except Exception as e:
            print("`{}` inputs are:".format(name))
            print_inputs(args)
            print("-------------------")
            raise e
        failed = False
        if was_cl:
            if isinstance(result, torch.Tensor):
                if result.dim() == 4 and not result.is_contiguous(memory_format=torch.channels_last):
                    print(
                        "`{}` got channels_last input, but output is not channels_last:".format(name),
                        result.shape,
                        result.stride(),
                        result.device,
                        result.dtype,
                    )
                    failed = True
        if failed:
            print("`{}` inputs are:".format(name))
            print_inputs(args)
            raise Exception("Operator `{}` lost channels_last property".format(name))
        return result
    return check_cl
old_attrs = dict()
def attribute(m):
    old_attrs[m] = dict()
    for i in dir(m):
        e = getattr(m, i)
        exclude_functions = ["is_cuda", "has_names", "numel", "stride", "Tensor", "is_contiguous", "__class__"]
        if i not in exclude_functions and not i.startswith("_") and "__call__" in dir(e):
            try:
                old_attrs[m][i] = e
                setattr(m, i, check_wrapper(e))
            except Exception as e:
                print(i)
                print(e)
attribute(torch.Tensor)
attribute(torch.nn.functional)
attribute(torch) 

If you find an operator that does not support channels last tensors and you want to contribute, feel free to use the following developer guide: github.com/pytorch/pytorch/wiki/Writing-memory-format-aware-operators

The code below recovers the attributes of torch.

for (m, attrs) in old_attrs.items():
    for (k, v) in attrs.items():
        setattr(m, k, v) 

Work to Do

There are still many things to do, such as:

  • resolving the ambiguity of N1HW and NC11 tensors;
  • testing support for distributed training;
  • improving operator coverage.

If you have feedback and/or suggestions for improvement, please let us know by creating an issue.

Total running time of the script: (0 minutes 0.038 seconds)

Download Python source code: memory_format_tutorial.py

Download Jupyter notebook: memory_format_tutorial.ipynb

Gallery generated by Sphinx-Gallery

Forward-mode Automatic Differentiation (Beta)

Original: pytorch.org/tutorials/intermediate/forward_ad_usage.html

Translator: 飞龙

License: CC BY-NC-SA 4.0

Note

Click here to download the full example code

This tutorial demonstrates how to use forward-mode automatic differentiation to compute directional derivatives (or equivalently, Jacobian-vector products).

The tutorial below uses some APIs that are only available in versions >= 1.11 (or nightly builds).

Also note that forward-mode AD is currently in beta. The API is subject to change and operator coverage is still incomplete.
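
A small sketch (not part of the tutorial) that guards against the version requirement mentioned above by parsing `torch.__version__`:

import torch

major, minor = (int(v) for v in torch.__version__.split(".")[:2])
assert (major, minor) >= (1, 11), "the forward-mode AD APIs used below require PyTorch >= 1.11"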

Basic Usage

Unlike reverse-mode AD, forward-mode AD computes gradients eagerly alongside the forward pass. We can use forward-mode AD to compute a directional derivative by performing the forward pass as before, except that we first associate our input with another tensor representing the direction of the directional derivative (or equivalently, the v in a Jacobian-vector product). When an input, which we call the "primal", is associated with a "direction" tensor, which we call the "tangent", the resulting new tensor object is called a "dual tensor" for its connection to dual numbers[0].

As the forward pass is performed, if any input tensors are dual tensors, extra computation is performed to propagate this "sensitivity" of the function.

import torch
import torch.autograd.forward_ad as fwAD
primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)
def fn(x, y):
    return x ** 2 + y ** 2
# All forward AD computation must be performed in the context of
# a ``dual_level`` context. All dual tensors created in such a context
# will have their tangents destroyed upon exit. This is to ensure that
# if the output or intermediate results of this computation are reused
# in a future forward AD computation, their tangents (which are associated
# with this computation) won't be confused with tangents from the later
# computation.
with fwAD.dual_level():
    # To create a dual tensor we associate a tensor, which we call the
    # primal with another tensor of the same size, which we call the tangent.
    # If the layout of the tangent is different from that of the primal,
    # The values of the tangent are copied into a new tensor with the same
    # metadata as the primal. Otherwise, the tangent itself is used as-is.
    #
    # It is also important to note that the dual tensor created by
    # ``make_dual`` is a view of the primal.
    dual_input = fwAD.make_dual(primal, tangent)
    assert fwAD.unpack_dual(dual_input).tangent is tangent
    # To demonstrate the case where the copy of the tangent happens,
    # we pass in a tangent with a layout different from that of the primal
    dual_input_alt = fwAD.make_dual(primal, tangent.T)
    assert fwAD.unpack_dual(dual_input_alt).tangent is not tangent
    # Tensors that do not have an associated tangent are automatically
    # considered to have a zero-filled tangent of the same shape.
    plain_tensor = torch.randn(10, 10)
    dual_output = fn(dual_input, plain_tensor)
    # Unpacking the dual returns a ``namedtuple`` with ``primal`` and ``tangent``
    # as attributes
    jvp = fwAD.unpack_dual(dual_output).tangent
assert fwAD.unpack_dual(dual_output).tangent is None 

Next part: PyTorch 2.2 Official Tutorials in Chinese (10)(3): https://developer.aliyun.com/article/1482541
