Spatial Transformer Networks Tutorial
Original:
pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html
Translator: 飞龙
Note
Click here to download the full example code
In this tutorial, you will learn how to augment your network using a visual attention mechanism called spatial transformer networks. You can read more about spatial transformer networks in the DeepMind paper.
Spatial transformer networks are a generalization of differentiable attention to any spatial transformation. Spatial transformer networks (STN for short) allow a neural network to learn how to perform spatial transformations on the input image in order to enhance the geometric invariance of the model. For example, it can crop a region of interest, scale, and correct the orientation of an image. This can be a useful mechanism because CNNs are not invariant to rotation and scale, or to more general affine transformations.
One of the best things about STN is the ability to simply plug it into any existing CNN with very little modification.
# License: BSD
# Author: Ghassen Hamrouni

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

plt.ion()   # interactive mode
<contextlib.ExitStack object at 0x7fc0914a7160>
Loading the data
In this post we experiment with the classic MNIST dataset, using a standard convolutional network augmented with a spatial transformer network.
from six.moves import urllib
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), batch_size=64, shuffle=True, num_workers=4)
# Test dataset
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), batch_size=64, shuffle=True, num_workers=4)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz
100%|##########| 9912422/9912422 [00:00<00:00, 367023704.91it/s]
Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz
100%|##########| 28881/28881 [00:00<00:00, 47653695.45it/s]
Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz
100%|##########| 1648877/1648877 [00:00<00:00, 343101225.21it/s]
Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|##########| 4542/4542 [00:00<00:00, 48107395.88it/s]
Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
Depicting spatial transformer networks
Spatial transformer networks boil down to three main components:
- The localization network is a regular CNN which regresses the transformation parameters. The transformation is never learned explicitly from this dataset; instead, the network learns automatically the spatial transformations that enhance the global accuracy.
- The grid generator generates, for each pixel of the output image, a grid of corresponding coordinates in the input image.
- The sampler uses the parameters of the transformation and applies it to the input image.
Note
We need the latest version of PyTorch that contains the affine_grid and grid_sample modules.
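Before looking at the full network, here is a minimal sketch (not part of the original tutorial) of how the grid generator and sampler cooperate: a hand-written 2 x 3 affine matrix stands in for the output of the localization network.

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 28, 28)                        # dummy MNIST-sized batch
theta = torch.tensor([[[1.0, 0.0, 0.25],             # identity plus a small horizontal shift
                       [0.0, 1.0, 0.00]]])           # shape (N, 2, 3)
grid = F.affine_grid(theta, x.size(), align_corners=False)   # grid generator
warped = F.grid_sample(x, grid, align_corners=False)         # sampler
print(warped.shape)  # torch.Size([1, 1, 28, 28])

Passing align_corners explicitly also avoids the UserWarning that appears in the training output further below.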
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)

        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)

        return x

    def forward(self, x):
        # transform the input
        x = self.stn(x)

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = Net().to(device)
Training the model
Now, let's use the SGD algorithm to train the model. The network learns the classification task in a supervised way. At the same time, the model learns the STN automatically in an end-to-end fashion.
optimizer = optim.SGD(model.parameters(), lr=0.01)

def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 500 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
#
# A simple test procedure to measure the STN performances on MNIST.
#

def test():
    with torch.no_grad():
        model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            # sum up batch loss
            test_loss += F.nll_loss(output, target, size_average=False).item()
            # get the index of the max log-probability
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
              .format(test_loss, correct, len(test_loader.dataset),
                      100. * correct / len(test_loader.dataset)))
Visualizing the STN results
Now, we will inspect the results of our learned visual attention mechanism.
We define a small helper function in order to visualize the transformations while training.
def convert_image_np(inp):
    """Convert a Tensor to numpy image."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp

# We want to visualize the output of the spatial transformers layer
# after the training, we visualize a batch of input images and
# the corresponding transformed batch using STN.

def visualize_stn():
    with torch.no_grad():
        # Get a batch of training data
        data = next(iter(test_loader))[0].to(device)

        input_tensor = data.cpu()
        transformed_input_tensor = model.stn(data).cpu()

        in_grid = convert_image_np(
            torchvision.utils.make_grid(input_tensor))

        out_grid = convert_image_np(
            torchvision.utils.make_grid(transformed_input_tensor))

        # Plot the results side-by-side
        f, axarr = plt.subplots(1, 2)
        axarr[0].imshow(in_grid)
        axarr[0].set_title('Dataset Images')

        axarr[1].imshow(out_grid)
        axarr[1].set_title('Transformed Images')

for epoch in range(1, 20 + 1):
    train(epoch)
    test()

# Visualize the STN transformation on some input batch
visualize_stn()

plt.ioff()
plt.show()
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/functional.py:4377: UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/functional.py:4316: UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.
Train Epoch: 1 [0/60000 (0%)]   Loss: 2.315648
Train Epoch: 1 [32000/60000 (53%)]      Loss: 1.051217
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.

Test set: Average loss: 0.2563, Accuracy: 9282/10000 (93%)

Train Epoch: 2 [0/60000 (0%)]   Loss: 0.544514
Train Epoch: 2 [32000/60000 (53%)]      Loss: 0.312879

Test set: Average loss: 0.1506, Accuracy: 9569/10000 (96%)

Train Epoch: 3 [0/60000 (0%)]   Loss: 0.408838
Train Epoch: 3 [32000/60000 (53%)]      Loss: 0.221301

Test set: Average loss: 0.1207, Accuracy: 9634/10000 (96%)

Train Epoch: 4 [0/60000 (0%)]   Loss: 0.400088
Train Epoch: 4 [32000/60000 (53%)]      Loss: 0.166533

Test set: Average loss: 0.1176, Accuracy: 9634/10000 (96%)

Train Epoch: 5 [0/60000 (0%)]   Loss: 0.274838
Train Epoch: 5 [32000/60000 (53%)]      Loss: 0.223936

Test set: Average loss: 0.2812, Accuracy: 9136/10000 (91%)

Train Epoch: 6 [0/60000 (0%)]   Loss: 0.411823
Train Epoch: 6 [32000/60000 (53%)]      Loss: 0.114000

Test set: Average loss: 0.0697, Accuracy: 9790/10000 (98%)

Train Epoch: 7 [0/60000 (0%)]   Loss: 0.066122
Train Epoch: 7 [32000/60000 (53%)]      Loss: 0.208773

Test set: Average loss: 0.0660, Accuracy: 9799/10000 (98%)

Train Epoch: 8 [0/60000 (0%)]   Loss: 0.201612
Train Epoch: 8 [32000/60000 (53%)]      Loss: 0.081877

Test set: Average loss: 0.0672, Accuracy: 9798/10000 (98%)

Train Epoch: 9 [0/60000 (0%)]   Loss: 0.077046
Train Epoch: 9 [32000/60000 (53%)]      Loss: 0.147858

Test set: Average loss: 0.0645, Accuracy: 9811/10000 (98%)

Train Epoch: 10 [0/60000 (0%)]  Loss: 0.086268
Train Epoch: 10 [32000/60000 (53%)]     Loss: 0.185868

Test set: Average loss: 0.0678, Accuracy: 9794/10000 (98%)

Train Epoch: 11 [0/60000 (0%)]  Loss: 0.138696
Train Epoch: 11 [32000/60000 (53%)]     Loss: 0.119381

Test set: Average loss: 0.0663, Accuracy: 9795/10000 (98%)

Train Epoch: 12 [0/60000 (0%)]  Loss: 0.145220
Train Epoch: 12 [32000/60000 (53%)]     Loss: 0.204023

Test set: Average loss: 0.0592, Accuracy: 9808/10000 (98%)

Train Epoch: 13 [0/60000 (0%)]  Loss: 0.118743
Train Epoch: 13 [32000/60000 (53%)]     Loss: 0.100721

Test set: Average loss: 0.0643, Accuracy: 9801/10000 (98%)

Train Epoch: 14 [0/60000 (0%)]  Loss: 0.066341
Train Epoch: 14 [32000/60000 (53%)]     Loss: 0.107528

Test set: Average loss: 0.0551, Accuracy: 9838/10000 (98%)

Train Epoch: 15 [0/60000 (0%)]  Loss: 0.022679
Train Epoch: 15 [32000/60000 (53%)]     Loss: 0.055676

Test set: Average loss: 0.0474, Accuracy: 9862/10000 (99%)

Train Epoch: 16 [0/60000 (0%)]  Loss: 0.102644
Train Epoch: 16 [32000/60000 (53%)]     Loss: 0.165537

Test set: Average loss: 0.0574, Accuracy: 9839/10000 (98%)

Train Epoch: 17 [0/60000 (0%)]  Loss: 0.280918
Train Epoch: 17 [32000/60000 (53%)]     Loss: 0.206559

Test set: Average loss: 0.0533, Accuracy: 9846/10000 (98%)

Train Epoch: 18 [0/60000 (0%)]  Loss: 0.052316
Train Epoch: 18 [32000/60000 (53%)]     Loss: 0.082710

Test set: Average loss: 0.0484, Accuracy: 9865/10000 (99%)

Train Epoch: 19 [0/60000 (0%)]  Loss: 0.083889
Train Epoch: 19 [32000/60000 (53%)]     Loss: 0.121432

Test set: Average loss: 0.0522, Accuracy: 9839/10000 (98%)

Train Epoch: 20 [0/60000 (0%)]  Loss: 0.067540
Train Epoch: 20 [32000/60000 (53%)]     Loss: 0.024880

Test set: Average loss: 0.0868, Accuracy: 9773/10000 (98%)
Total running time of the script: (3 minutes 30.487 seconds)
Download Python source code: spatial_transformer_tutorial.py
Download Jupyter notebook: spatial_transformer_tutorial.ipynb
Optimizing Vision Transformer Model for Deployment
Original:
pytorch.org/tutorials/beginner/vt_tutorial.html
Translator: 飞龙
Note
Click here to download the full example code
Vision Transformer models apply cutting-edge, attention-based transformer models, introduced in natural language processing, to computer vision tasks to achieve all kinds of state-of-the-art (SOTA) results. Facebook Data-efficient Image Transformers (DeiT) is a Vision Transformer model trained on ImageNet for image classification.
In this tutorial, we will first cover what DeiT is and how to use it, then go through the complete steps of scripting, quantizing, optimizing, and using the model in iOS and Android apps. We will also compare the performance of the quantized, optimized model against the non-quantized, non-optimized model, and show the benefits of applying quantization and optimization along the way.
What is DeiT
Convolutional Neural Networks (CNNs) have been the main models for image classification since deep learning took off in 2012, but CNNs typically require hundreds of millions of images for training to achieve SOTA results. DeiT is a vision transformer model that requires a lot less data and computing resources for training to compete with the leading CNNs in performing image classification, which is made possible by two key components of DeiT:
- Data augmentation that simulates training on a much larger dataset;
- Native distillation that allows the transformer network to learn from a CNN's output.
DeiT shows that Transformers can be successfully applied to computer vision tasks, with limited access to data and resources. For more details on DeiT, see the repo and paper.
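As a rough illustration of the distillation component mentioned above, the snippet below sketches a generic soft-distillation loss. DeiT's actual recipe uses a dedicated distillation token and a hard-label variant, so this is only meant to convey the idea, not to reproduce the paper's method.

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, targets, T=3.0, alpha=0.5):
    # Standard cross-entropy against the ground-truth labels.
    hard = F.cross_entropy(student_logits, targets)
    # KL divergence between the softened student and teacher distributions.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)
    return alpha * hard + (1 - alpha) * soft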
Classifying Images with DeiT
Follow the README.md in the DeiT repository for detailed information on how to classify images using DeiT, or for a quick test, first install the required packages:
pip install torch torchvision timm pandas requests
To run in Google Colab, install the dependencies by running the following command:
!pip install timm pandas requests
Then run the script below:
from PIL import Image
import torch
import timm
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

print(torch.__version__)
# should be 1.8.0

model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()

transform = transforms.Compose([
    transforms.Resize(256, interpolation=3),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

img = Image.open(requests.get("https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png", stream=True).raw)
img = transform(img)[None,]
out = model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
2.2.0+cu121
Downloading: "https://github.com/facebookresearch/deit/zipball/main" to /var/lib/jenkins/.cache/torch/hub/main.zip
/var/lib/jenkins/.cache/torch/hub/facebookresearch_deit_main/models.py:63: UserWarning: Overwriting deit_tiny_patch16_224 in registry with models.deit_tiny_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/jenkins/.cache/torch/hub/facebookresearch_deit_main/models.py:78: UserWarning: Overwriting deit_small_patch16_224 in registry with models.deit_small_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/jenkins/.cache/torch/hub/facebookresearch_deit_main/models.py:93: UserWarning: Overwriting deit_base_patch16_224 in registry with models.deit_base_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/jenkins/.cache/torch/hub/facebookresearch_deit_main/models.py:108: UserWarning: Overwriting deit_tiny_distilled_patch16_224 in registry with models.deit_tiny_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/jenkins/.cache/torch/hub/facebookresearch_deit_main/models.py:123: UserWarning: Overwriting deit_small_distilled_patch16_224 in registry with models.deit_small_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/jenkins/.cache/torch/hub/facebookresearch_deit_main/models.py:138: UserWarning: Overwriting deit_base_distilled_patch16_224 in registry with models.deit_base_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/jenkins/.cache/torch/hub/facebookresearch_deit_main/models.py:153: UserWarning: Overwriting deit_base_patch16_384 in registry with models.deit_base_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/jenkins/.cache/torch/hub/facebookresearch_deit_main/models.py:168: UserWarning: Overwriting deit_base_distilled_patch16_384 in registry with models.deit_base_distilled_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /var/lib/jenkins/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth 0%| | 0.00/330M [00:00<?, ?B/s] 4%|3 | 12.4M/330M [00:00<00:02, 130MB/s] 7%|7 | 24.7M/330M [00:00<00:02, 110MB/s] 11%|#1 | 36.8M/330M [00:00<00:02, 117MB/s] 15%|#4 | 49.2M/330M [00:00<00:02, 121MB/s] 19%|#8 | 62.2M/330M [00:00<00:02, 127MB/s] 23%|##3 | 76.7M/330M [00:00<00:01, 135MB/s] 27%|##7 | 90.6M/330M [00:00<00:01, 139MB/s] 32%|###1 | 106M/330M [00:00<00:01, 144MB/s] 36%|###6 | 119M/330M [00:00<00:01, 125MB/s] 40%|###9 | 132M/330M [00:01<00:01, 122MB/s] 45%|####4 | 147M/330M [00:01<00:01, 132MB/s] 49%|####8 | 162M/330M [00:01<00:01, 138MB/s] 53%|#####3 | 176M/330M [00:01<00:01, 142MB/s] 58%|#####7 | 190M/330M [00:01<00:01, 144MB/s] 62%|######2 | 205M/330M [00:01<00:00, 147MB/s] 67%|######6 | 220M/330M [00:01<00:00, 149MB/s] 71%|####### | 234M/330M [00:01<00:00, 148MB/s] 76%|#######5 | 250M/330M [00:01<00:00, 155MB/s] 81%|########1 | 268M/330M [00:01<00:00, 162MB/s] 86%|########6 | 285M/330M [00:02<00:00, 168MB/s] 91%|#########1| 302M/330M [00:02<00:00, 172MB/s] 97%|#########6| 319M/330M [00:02<00:00, 175MB/s] 100%|##########| 330M/330M [00:02<00:00, 147MB/s] 269
The output should be 269, which, according to the ImageNet list of class index to labels file, maps to timber wolf, grey wolf, gray wolf, Canis lupus.
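If you want to turn the index into a readable label programmatically, one option (not part of the original tutorial) is to look it up in an ImageNet class-index JSON file; the URL below is an assumed, commonly mirrored copy, so substitute your own copy if it is unavailable.

import json
import requests

# Hypothetical lookup; the URL is an assumed mirror of the ImageNet class-index file.
IDX_URL = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
class_idx = json.loads(requests.get(IDX_URL).text)
print(class_idx["269"])  # expected: ['n02114367', 'timber_wolf']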
Now that we have verified that we can use the DeiT model to classify images, let's see how to modify the model so it can run on iOS and Android apps.
Scripting DeiT
To use the model on mobile, we first need to script the model. See the Script and Optimize recipe for a quick overview. Run the code below to convert the DeiT model to the TorchScript format that can run on mobile.
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()

scripted_model = torch.jit.script(model)
scripted_model.save("fbdeit_scripted.pt")
Using cache found in /var/lib/jenkins/.cache/torch/hub/facebookresearch_deit_main
The generated scripted model file fbdeit_scripted.pt is about 346MB in size.
Quantizing DeiT
To reduce the trained model size significantly while keeping the inference accuracy about the same, quantization can be applied to the model. Thanks to the transformer model used in DeiT, we can easily apply dynamic quantization, because dynamic quantization works best for LSTM and transformer models (see here for more details).
Now run the code below:
# Use 'x86' for server inference (the old 'fbgemm' is still available but 'x86' is the recommended default) and ``qnnpack`` for mobile inference.
backend = "x86"  # replaced with ``qnnpack`` causing much worse inference speed for quantized model on this notebook
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend

quantized_model = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
scripted_quantized_model = torch.jit.script(quantized_model)
scripted_quantized_model.save("fbdeit_scripted_quantized.pt")
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/ao/quantization/observer.py:220: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.
This generates the scripted and quantized version of the model, fbdeit_scripted_quantized.pt (the filename used in the code above), with a size of about 89MB, a 74% reduction from the non-quantized model size of 346MB!
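As a quick sanity check (not in the original tutorial), you can compare the two files on disk yourself; the exact numbers will vary slightly with your PyTorch version.

import os

scripted_mb = os.path.getsize("fbdeit_scripted.pt") / 1e6
quantized_mb = os.path.getsize("fbdeit_scripted_quantized.pt") / 1e6
print(f"scripted: {scripted_mb:.0f} MB, quantized: {quantized_mb:.0f} MB, "
      f"reduction: {100 * (1 - quantized_mb / scripted_mb):.0f}%")  # roughly (346 - 89) / 346, i.e. about 74%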
You can use scripted_quantized_model to generate the same inference result:
out = scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# The same output 269 should be printed
269
Optimizing DeiT
The final step before putting the quantized and scripted model on mobile is to optimize it:
from torch.utils.mobile_optimizer import optimize_for_mobile
optimized_scripted_quantized_model = optimize_for_mobile(scripted_quantized_model)
optimized_scripted_quantized_model.save("fbdeit_optimized_scripted_quantized.pt")
The generated fbdeit_optimized_scripted_quantized.pt file is about the same size as the quantized, scripted, but non-optimized model. The inference result remains the same.
out = optimized_scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# Again, the same output 269 should be printed
269
Using the Lite Interpreter
To see how much model size reduction and inference speedup the Lite Interpreter can result in, let's create the lite version of the model.
optimized_scripted_quantized_model._save_for_lite_interpreter("fbdeit_optimized_scripted_quantized_lite.ptl")
ptl = torch.jit.load("fbdeit_optimized_scripted_quantized_lite.ptl")
Although the lite model size is comparable to the non-lite version, an inference speedup is expected when running the lite version on mobile.
Comparing Inference Speed
To see how the inference speed differs for the four models - the original model, the scripted model, the quantized and scripted model, and the optimized quantized and scripted model - run the code below:
with torch.autograd.profiler.profile(use_cuda=False) as prof1:
    out = model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof2:
    out = scripted_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof3:
    out = scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof4:
    out = optimized_scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof5:
    out = ptl(img)

print("original model: {:.2f}ms".format(prof1.self_cpu_time_total/1000))
print("scripted model: {:.2f}ms".format(prof2.self_cpu_time_total/1000))
print("scripted & quantized model: {:.2f}ms".format(prof3.self_cpu_time_total/1000))
print("scripted & quantized & optimized model: {:.2f}ms".format(prof4.self_cpu_time_total/1000))
print("lite model: {:.2f}ms".format(prof5.self_cpu_time_total/1000))
original model: 123.27ms
scripted model: 111.89ms
scripted & quantized model: 129.99ms
scripted & quantized & optimized model: 129.94ms
lite model: 120.00ms
The results running on Google Colab are:
original model: 1236.69ms
scripted model: 1226.72ms
scripted & quantized model: 593.19ms
scripted & quantized & optimized model: 598.01ms
lite model: 600.72ms
The following results summarize the inference time taken by each model and the percentage reduction of each model relative to the original model.
import pandas as pd
import numpy as np

df = pd.DataFrame({'Model': ['original model',
                             'scripted model',
                             'scripted & quantized model',
                             'scripted & quantized & optimized model',
                             'lite model']})
df = pd.concat([df, pd.DataFrame([
    ["{:.2f}ms".format(prof1.self_cpu_time_total/1000), "0%"],
    ["{:.2f}ms".format(prof2.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof2.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof3.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof3.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof4.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof4.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof5.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof5.self_cpu_time_total)/prof1.self_cpu_time_total*100)]],
    columns=['Inference Time', 'Reduction'])], axis=1)

print(df)

"""
        Model                                  Inference Time    Reduction
0       original model                              1236.69ms           0%
1       scripted model                              1226.72ms        0.81%
2       scripted & quantized model                   593.19ms       52.03%
3       scripted & quantized & optimized model       598.01ms       51.64%
4       lite model                                   600.72ms       51.43%
"""
                                    Model  ... Reduction
0                          original model  ...        0%
1                          scripted model  ...     9.23%
2              scripted & quantized model  ...    -5.45%
3  scripted & quantized & optimized model  ...    -5.41%
4                              lite model  ...     2.65%

[5 rows x 3 columns]
Learn More
- Facebook Data-efficient Image Transformers
- Vision Transformer with ImageNet and MNIST on iOS
- Vision Transformer with ImageNet and MNIST on Android
Total running time of the script: (0 minutes 20.779 seconds)
Download Python source code: vt_tutorial.py
Download Jupyter notebook: vt_tutorial.ipynb