TorchScript 系列解读(一):初识 TorchScript

简介: PyTorch 无疑是现在最成功的深度学习训练框架之一,是各种顶会顶刊论文实验的大热门。比起其他的框架,PyTorch 最大的卖点是它对动态网络的支持,比其他需要构建静态网络的框架拥有更低的学习成本。

640.gif

小伙伴们好呀,不久前我们推出了模型部署入门系列教程,受到了大家的一致好评,也收到了很多小伙伴的催更,后续教程正在准备中,将在不久后跟大家见面,敬请期待哦~

今天,我们又将开启新的 TorchScript 解读系列教程,带领大家玩转 PyTorch 模型部署。感兴趣的小伙伴一起往下看吧~

什么是 TorchScript



PyTorch 无疑是现在最成功的深度学习训练框架之一,是各种顶会顶刊论文实验的大热门。比起其他的框架,PyTorch 最大的卖点是它对动态网络的支持,比其他需要构建静态网络的框架拥有更低的学习成本。PyTorch 源码 Readme 中还专门为此做了一张动态图:

640.gif

对研究员而言,PyTorch 能极大地提高想 idea、做实验、发论文的效率,是训练框架中的豪杰,但是它不适合部署。动态建图带来的优势对于性能要求更高的应用场景而言更像是缺点,非固定的网络结构给网络结构分析并进行优化带来了困难,多数参数都能以 Tensor 形式传输也让资源分配变成一件闹心的事。另外由于图是由 python 代码来构建的,一方面部署要依赖 python 环境,另一方面模型也毫无保密性可言。


而 TorchScript 就是为了解决这个问题而诞生的工具。包括代码的追踪及解析、中间表示的生成、模型优化、序列化等各种功能,可以说是覆盖了模型部署的方方面面。今天我们先简要地介绍一些 TorchScript 的功能,让大家有一个初步的认识,进阶的解读会陆续推出~


模型转换


作为模型部署的一个范式,通常我们都需要生成一个模型的中间表示(IR),这个 IR 拥有相对固定的图结构,所以更容易优化,让我们看一个例子:


import torch
from torchvision.models import resnet18
# 使用PyTorch model zoo中的resnet18作为例子
model = resnet18()
model.eval()
# 通过trace的方法生成IR需要一个输入样例
dummy_input = torch.rand(1, 3, 224, 224)
# IR生成
with torch.no_grad():
    jit_model = torch.jit.trace(model, dummy_input)

到这里就将 PyTorch 的模型转换成了 TorchScript 的 IR。这里我们使用了 trace 模式来生成 IR,所谓 trace 指的是进行一次模型推理,在推理的过程中记录所有经过的计算,将这些记录整合成计算图。关于 trace 的过程我们会在未来的分享中进行解读。


那么这个 IR 中到底都有些什么呢?我们可以可视化一下其中的 layer1 看看:

jit_layer1 = jit_model.layer1
print(jit_layer1.graph)
# graph(%self.6 : __torch__.torch.nn.modules.container.Sequential,
#       %4 : Float(1, 64, 56, 56, strides=[200704, 3136, 56, 1], requires_grad=0, device=cpu)):
#   %1 : __torch__.torchvision.models.resnet.___torch_mangle_10.BasicBlock = prim::GetAttr[name="1"](%self.6)
#   %2 : __torch__.torchvision.models.resnet.BasicBlock = prim::GetAttr[name="0"](%self.6)
#   %6 : Tensor = prim::CallMethod[name="forward"](%2, %4)
#   %7 : Tensor = prim::CallMethod[name="forward"](%1, %6)
#   return (%7)

是不是有点摸不着头脑?TorchScript 有它自己对于 Graph 以及其中元素的定义,对于第一次接触的人来说可能比较陌生,但是没关系,我们还有另一种可视化方式:

print(jit_layer1.code)
# def forward(self,
#     argument_1: Tensor) -> Tensor:
#   _0 = getattr(self, "1")
#   _1 = (getattr(self, "0")).forward(argument_1, )
#   return (_0).forward(_1, )

没错,就是代码!TorchScript 的 IR 是可以还原成 python 代码的,如果你生成了一个 TorchScript 模型并且想知道它的内容对不对,那么可以通过这样的方式来做一些简单的检查。


刚才的例子中我们使用 trace 的方法生成 IR。除了 trace 之外,PyTorch 还提供了另一种生成 TorchScript 模型的方法:script。这种方式会直接解析网络定义的 python 代码,生成抽象语法树 AST,因此这种方法可以解决一些 trace 无法解决的问题,比如对 branch/loop 等数据流控制语句的建图。script 方式的建图有很多有趣的特性,会在未来的分享中做专题分析,敬请期待。


模型优化


聪明的同学可能发现了,上面的可视化中只有 resnet18 里 forward 的部分,其中的子模块信息是不是丢失了呢?如果没有丢失,那么怎么样才能确定子模块的内容是否正确呢?别担心,还记得我们说过 TorchScript 支持对网络的优化吗,这里我们就可以用一个 pass 解决这个问题:

# 调用inline pass,对graph做变换
torch._C._jit_pass_inline(jit_layer1.graph)
print(jit_layer1.code)
# def forward(self,
#     argument_1: Tensor) -> Tensor:
#   _0 = getattr(self, "1")
#   _1 = getattr(self, "0")
#   _2 = _1.bn2
#   _3 = _1.conv2
#   _4 = _1.bn1
#   input = torch._convolution(argument_1, _1.conv1.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)
#   _5 = _4.running_var
#   _6 = _4.running_mean
#   _7 = _4.bias
#   input0 = torch.batch_norm(input, _4.weight, _7, _6, _5, False, 0.10000000000000001, 1.0000000000000001e-05, True)
#   input1 = torch.relu_(input0)
#   input2 = torch._convolution(input1, _3.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)
#   _8 = _2.running_var
#   _9 = _2.running_mean
#   _10 = _2.bias
#   out = torch.batch_norm(input2, _2.weight, _10, _9, _8, False, 0.10000000000000001, 1.0000000000000001e-05, True)
#   input3 = torch.add_(out, argument_1, alpha=1)
#   input4 = torch.relu_(input3)
#   _11 = _0.bn2
#   _12 = _0.conv2
#   _13 = _0.bn1
#   input5 = torch._convolution(input4, _0.conv1.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)
#   _14 = _13.running_var
#   _15 = _13.running_mean
#   _16 = _13.bias
#   input6 = torch.batch_norm(input5, _13.weight, _16, _15, _14, False, 0.10000000000000001, 1.0000000000000001e-05, True)
#   input7 = torch.relu_(input6)
#   input8 = torch._convolution(input7, _12.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)
#   _17 = _11.running_var
#   _18 = _11.running_mean
#   _19 = _11.bias
#   out0 = torch.batch_norm(input8, _11.weight, _19, _18, _17, False, 0.10000000000000001, 1.0000000000000001e-05, True)
#   input9 = torch.add_(out0, input4, alpha=1)
#   return torch.relu_(input9)

这里我们就能看到卷积、batch_norm、relu 等熟悉的算子了。


上面代码中我们使用了一个名为 inline 的 pass,将所有子模块进行内联,这样我们就能看见更完整的推理代码。pass 是一个来源于编译原理的概念,一个 TorchScript 的 pass 会接收一个图,遍历图中所有元素进行某种变换,生成一个新的图。我们这里用到的 inline 起到的作用就是将模块调用展开,尽管这样做并不能直接影响执行效率,但是它其实是很多其他 pass 的基础。PyTorch 中定义了非常多的 pass 来解决各种优化任务,未来我们会做一些更详细的介绍。


序列化


不管是哪种方法创建的 TorchScript 都可以进行序列化,比如:

# 将模型序列化
jit_model.save('jit_model.pth')
# 加载序列化后的模型
jit_model = torch.jit.load('jit_model.pth')

序列化后的模型不再与 python 相关,可以被部署到各种平台上。


PyTorch 提供了可以用于 TorchScript 模型推理的 c++ API,序列化后的模型终于可以不依赖 python 进行推理了:

// 加载生成的torchscript模型
auto module = torch::jit::load('jit_model.pth');
// 根据任务需求读取数据
std::vector<torch::jit::IValue> inputs = ...;
// 计算推理结果
auto output = module.forward(inputs).toTensor();


与其他组件的关系



与 torch.onnx 的关系

640.jpg640.png



ONNX 是业界广泛使用的一种神经网络中间表示,PyTorch 自然也对 ONNX 提供了支持。torch.onnx.export 函数可以帮助我们把 PyTorch 模型转换成 ONNX 模型,这个函数会使用 trace 的方式记录 PyTorch 的推理过程。聪明的同学可能已经想到了,没错,ONNX 的导出,使用的正是 TorchScript 的 trace 工具。具体步骤如下:


1. 使用 trace 的方式先生成一个 TorchScipt 模型,如果你转换的本身就是 TorchScript 模型,则可以跳过这一步。

2. 使用许多 pass 对 1 中生成的模型进行变换,其中对 ONNX 导出最重要的一个 pass 就是ToONNX,这个 pass 会进行一个映射,将 TorchScript 中 prim、aten 空间下的算子映射到onnx空间下的算子。

3. 使用 ONNX 的 proto 格式对模型进行序列化,完成 ONNX 的导出。


关于 ONNX 导出的实现以及算子映射的方式将会在未来的分享中详细展开。

与 torch.fx 的关系


PyTorch1.9 开始添加了 torch.fx 工具,根据官方的介绍,它由符号追踪器 (symbolic tracer),中间表示(IR), Python 代码生成 (Python code generation) 等组件组成,实现了 python->python 的翻译。是不是和 TorchScript 看起来有点像?


其实他们之间联系不大,可以算是互相垂直的两个工具,为解决两个不同的任务而诞生。


TorchScript 的主要用途是进行模型部署,需要记录生成一个便于推理优化的 IR,对计算图的编辑通常都是面向性能提升等等,不会给模型本身添加新的功能。


FX 的主要用途是进行 python->python 的翻译,它的 IR 中节点类型更简单,比如函数调用、属性提取等等,这样的 IR 学习成本更低更容易编辑。使用 FX 来编辑图通常是为了实现某种特定功能,比如给模型插入量化节点等,避免手动编辑网络造成的重复劳动。


这两个工具可以同时使用,比如使用 FX 工具编辑模型来让训练更便利、功能更强大;然后用 TorchScript 将模型加速部署到特定平台。


希望通过以上的分享,大家对 TorchScript 有了一个初步的认识,未来我们将会为大家带来更进阶的解读,欢迎大家持续关注。


文章来源:【OpenMMLab

2022-03-24 18:05


目录
相关文章
|
机器学习/深度学习 存储 并行计算
一篇就够:高性能推理引擎理论与实践 (TensorRT)
本文分享了关于 NVIDIA 推出的高性能的深度学习推理引擎 TensorRT 的背后理论知识和实践操作指南。
7796 6
一篇就够:高性能推理引擎理论与实践 (TensorRT)
|
Shell Linux Android开发
【Linux】【编译相关】execvp: /bin/sh: Argument list too long问题处理小结
【Linux】【编译相关】execvp: /bin/sh: Argument list too long问题处理小结
975 0
|
9月前
|
机器学习/深度学习 并行计算 PyTorch
torch.jit.script 与 torch.jit.trace
torch.jit.script 与 torch.jit.trace
433 0
|
11月前
|
机器学习/深度学习 数据可视化 PyTorch
模型推理加速系列 | 05: 推理加速格式TorchScript简介及其应用
本文主要TorchScript的基本概念及其在 C++ 中的使用
|
11月前
【ULP】什么是ULP
【ULP】什么是ULP
264 0
|
11月前
|
缓存 PyTorch 算法框架/工具
PyTorch distributed barrier 引发的陷阱
PyTorch distributed barrier 引发的陷阱
337 0
|
11月前
|
存储 PyTorch 算法框架/工具
一文读懂—Pytorch混合精度训练
一文读懂—Pytiorch混合精度训练
462 0
|
11月前
|
缓存 编译器 Linux
CMake链接第三方库
CMake链接第三方库
643 1
|
SQL 机器学习/深度学习 人工智能
王炸:这个GitHub 20000+ Star的OCR项目迎来四大重磅升级
王炸:这个GitHub 20000+ Star的OCR项目迎来四大重磅升级
550 0
|
机器学习/深度学习 自然语言处理 前端开发
TorchScript 解读:jit 中的 subgraph rewriter
现代的深度学习推理框架通常遵循编译器的范式,将模型的中间表示(IR)会分为两部分:包括与硬件、环境等无关的前端(frontend)以及针对特定环境的后端(backend),比如 TVM 的 Relay 和 tir 就是一个典型的例子。在 PyTorch 的 jit 中源码中,也包含前端与后端的部分(不过后端部分的更新似乎不是很频繁)。frontend 目录下有对 Graph IR 的定义,function_schema 的解析工具,以及将 torchscript 转换成 SSA(static single assignment)形式的转换器等等。
245 0
TorchScript 解读:jit 中的 subgraph rewriter