torch.jit.script 与 torch.jit.trace

本文涉及的产品
函数计算FC,每月15万CU 3个月
简介: torch.jit.script 与 torch.jit.trace

torch.jit.script 与 torch.jit.trace

torch.jit.script torch.jit.trace 是 PyTorch 中用于将模型转换为脚本或跟踪模型执行的工具。

它们是 PyTorch 的即时编译(Just-in-Time Compilation)模块的一部分,用于提高模型的执行效率并支持模型的部署。

torch.jit.script

torch.jit.script 是将模型转换为脚本的函数。

它接受一个 PyTorch 模型作为输入,并将其转换为可运行的脚本。转换后的脚本可以像普通的 Python 函数一样调用,也可以保存到磁盘并在没有 PyTorch 依赖的环境中执行。

这种转换的好处是可以减少模型执行过程中的开销,因为它消除了 Python 解释器的开销。

示例:

import torch
# 定义模型
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.fc = torch.nn.Linear(64 * 8 * 8, 10)
    def forward(self, x):
        x = self.conv(x)
        x = torch.nn.functional.relu(x)
        x = x.view(-1, 64 * 8 * 8)
        x = self.fc(x)
        return x
model = MyModel()
# 将模型转换为Torch脚本模块
scripted_model = torch.jit.script(model)
# 调用
output = scripted_model(torch.randn(1, 3, 32, 32))
print(output)
# 保存模型
torch.jit.save(scripted_model, './model/Test/scripted_model.pth')

torch.jit.trace

torch.jit.trace 是跟踪模型执行的函数。

它接受一个模型和一个示例输入,并记录模型在给定输入上的执行过程,然后返回一个跟踪模型。

跟踪模型可以看作是一个具有相同功能的脚本模型,但它还保留了原始模型的动态特性,可以使用更多高级特性,如动态图和控制流。

示例:

import torch
# 定义模型
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.fc = torch.nn.Linear(64 * 8 * 8, 10)
    def forward(self, x):
        x = self.conv(x)
        x = torch.nn.functional.relu(x)
        x = x.view(-1, 64 * 8 * 8)
        x = self.fc(x)
        return x
model = MyModel()
# 将模型转换为Torch脚本模块
traced_model = torch.jit.trace(model, torch.randn(1, 3, 32, 32))
# 调用
output = traced_model(torch.randn(1, 3, 32, 32))
print(output)
# 保存模型
torch.jit.save(traced_model, './model/Test/traced_model.pth')

注意

由于 torch.jit.trace 方法只跟踪了给定输入张量的执行路径,因此在使用转换后的模块对象进行推理时,输入张量的维度和数据类型必须与跟踪时使用的相同。

torch.jit.save

使用 torch.jit.scripttorch.jit.trace 转换后的模块对象可以直接用于推理,也可以使用 torch.jit.save 方法将其保存到文件中,以便在需要时加载模型。

torch.jit.load

使用 torch.jit.load 函数可以加载 PyTorch 模型,该函数可以接收一个模型文件路径或一个文件对象作为输入参数。具体步骤如下:

  • 加载模型文件:
import torch
model = torch.jit.load("model.pt")

这将加载名为 model.pt 的模型文件。

  • 加载模型文件并指定设备:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.jit.load("model.pt", map_location=device)

这将加载名为 model.pt 的模型文件,并将其放置在可用的 CUDA 设备上。

  • 加载模型文件并使用 eval 模式:
import torch
model = torch.jit.load("model.pt")
model.eval()

这将加载名为 model.pt 的模型文件,并将其转换为评估模式。

注意:

如果模型使用了特定的设备,例如 CUDA,那么在加载模型时需要确保该设备可用。如果设备不可用,则需要使用 map_location 参数将模型映射到可用的设备上。

Code

import torch
# 定义模型
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.fc = torch.nn.Linear(64 * 8 * 8, 10)
    def forward(self, x):
        x = self.conv(x)
        x = torch.nn.functional.relu(x)
        x = x.view(-1, 64 * 8 * 8)
        x = self.fc(x)
        return x
model = MyModel()
print(model)
# 将模型转换为Torch脚本模块
scripted_model = torch.jit.script(model)
traced_model = torch.jit.trace(model, torch.randn(1, 3, 32, 32))
# 调用
output_scripted = scripted_model(torch.randn(1, 3, 32, 32))
output_traced = traced_model(torch.randn(1, 3, 32, 32))
# 保存模型
torch.jit.save(scripted_model, './model/Test/scripted_model.pth')
torch.jit.save(traced_model, './model/Test/traced_model.pth')
# 加载模型
load_scripted_model = torch.jit.load('./model/Test/scripted_model.pth')
print(load_scripted_model)
load_traced_model = torch.jit.load('./model/Test/traced_model.pth')
print(load_traced_model)
MyModel(
  (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc): Linear(in_features=4096, out_features=10, bias=True)
)
RecursiveScriptModule(
  original_name=MyModel
  (conv): RecursiveScriptModule(original_name=Conv2d)
  (fc): RecursiveScriptModule(original_name=Linear)
)
RecursiveScriptModule(
  original_name=MyModel
  (conv): RecursiveScriptModule(original_name=Conv2d)
  (fc): RecursiveScriptModule(original_name=Linear)
)

说明:

  • RecursiveScriptModule表示一个递归的 TorchScript 模块,类似于一个树形结构。
  • 该模块的原始名称为 MyModel,表示这是一个模型的容器。
  • 该容器包含了两个子模块 convfc,分别是 Conv2dLinear 的递归脚本模块。意味着这两个子模块也是 TorchScript 模块,并可以在 TorchScript 中进行运算。
  • RecursiveScriptModule 可以通过 torch.jit.scripttorch.jit.trace 将 PyTorch 模型转换为 TorchScript 模块。在转换过程中,每个子模块也会被转换为相应的 TorchScript 模块,并嵌套在父模块中。
  • 这种嵌套结构可以很好地表示深度学习模型的层次结构。
  • RecursiveScriptModule中的模块名称和原始名称可以通过original_name属性进行访问。
  • 例如,MyModel 的原始名称是 MyModelconv 模块的原始名称是 Conv2dfc 模块的原始名称是 Linear
相关实践学习
【文生图】一键部署Stable Diffusion基于函数计算
本实验教你如何在函数计算FC上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。函数计算提供一定的免费额度供用户使用。本实验答疑钉钉群:29290019867
建立 Serverless 思维
本课程包括: Serverless 应用引擎的概念, 为开发者带来的实际价值, 以及让您了解常见的 Serverless 架构模式
目录
相关文章
|
并行计算 PyTorch 算法框架/工具
【pytorch】解决pytorch:Torch not compiled with CUDA enabled
【pytorch】解决pytorch:Torch not compiled with CUDA enabled
4895 0
|
3月前
|
TensorFlow API 算法框架/工具
【Tensorflow】解决Inputs to eager execution function cannot be Keras symbolic tensors, but found [<tf.Te
文章讨论了在使用Tensorflow 2.3时遇到的一个错误:"Inputs to eager execution function cannot be Keras symbolic tensors...",这个问题通常与Tensorflow的eager execution(急切执行)模式有关,提供了三种解决这个问题的方法。
39 1
|
3月前
|
机器学习/深度学习 Linux TensorFlow
【Tensorflow 2】解决tensorflow.python.framework.errors_impl.UnknownError: [_Derived_] Fail to find the...
本文解决了在使用TensorFlow 2.0和Keras API时,尝试使用双向LSTM (tf.keras.layers.Bidirectional) 出现的未知错误问题,并提供了三种解决该问题的方法。
54 3
|
3月前
|
机器学习/深度学习 API TensorFlow
【Tensorflow+keras】解决 Fail to find the dnn implementation.
在TensorFlow 2.0环境中使用双向长短期记忆层(Bidirectional LSTM)遇到“Fail to find the dnn implementation”错误时的三种解决方案。
71 0
|
5月前
|
并行计算 监控 前端开发
函数计算操作报错合集之如何解决报错:RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!
在使用函数计算服务(如阿里云函数计算)时,用户可能会遇到多种错误场景。以下是一些常见的操作报错及其可能的原因和解决方法,包括但不限于:1. 函数部署失败、2. 函数执行超时、3. 资源不足错误、4. 权限与访问错误、5. 依赖问题、6. 网络配置错误、7. 触发器配置错误、8. 日志与监控问题。
212 2
|
5月前
|
并行计算 异构计算 Python
python代码torch.device("cuda:0" if torch.cuda.is_available() else "cpu")是什么意思?
【6月更文挑战第3天】python代码torch.device("cuda:0" if torch.cuda.is_available() else "cpu")是什么意思?
557 4
AttributeError: module ‘torch.jit‘ has no attribute ‘_script_if_tracing‘
AttributeError: module ‘torch.jit‘ has no attribute ‘_script_if_tracing‘
288 0
|
TensorFlow 算法框架/工具
成功解决AttributeError: module 'tensorflow.python.keras' has no attribute 'Model'
成功解决AttributeError: module 'tensorflow.python.keras' has no attribute 'Model'
|
PyTorch 算法框架/工具 异构计算
Pytorch出现RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor)
这个问题的主要原因是输入的数据类型与网络参数的类型不符。
592 0
|
机器学习/深度学习 PyTorch 算法框架/工具
# Pytorch 中可以直接调用的Loss Functions总结:(二)
# Pytorch 中可以直接调用的Loss Functions总结:(二)
158 0
# Pytorch 中可以直接调用的Loss Functions总结:(二)