torch.jit.script 与 torch.jit.trace
torch.jit.script
和 torch.jit.trace
是 PyTorch 中用于将模型转换为脚本或跟踪模型执行的工具。
它们是 PyTorch 的即时编译(Just-in-Time Compilation)模块的一部分,用于提高模型的执行效率并支持模型的部署。
torch.jit.script
torch.jit.script
是将模型转换为脚本的函数。
它接受一个 PyTorch 模型作为输入,并将其转换为可运行的脚本。转换后的脚本可以像普通的 Python 函数一样调用,也可以保存到磁盘并在没有 PyTorch 依赖的环境中执行。
这种转换的好处是可以减少模型执行过程中的开销,因为它消除了 Python 解释器的开销。
示例:
import torch # 定义模型 class MyModel(torch.nn.Module): def __init__(self): super(MyModel, self).__init__() self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1) self.fc = torch.nn.Linear(64 * 8 * 8, 10) def forward(self, x): x = self.conv(x) x = torch.nn.functional.relu(x) x = x.view(-1, 64 * 8 * 8) x = self.fc(x) return x model = MyModel() # 将模型转换为Torch脚本模块 scripted_model = torch.jit.script(model) # 调用 output = scripted_model(torch.randn(1, 3, 32, 32)) print(output) # 保存模型 torch.jit.save(scripted_model, './model/Test/scripted_model.pth')
torch.jit.trace
torch.jit.trace
是跟踪模型执行的函数。
它接受一个模型和一个示例输入,并记录模型在给定输入上的执行过程,然后返回一个跟踪模型。
跟踪模型可以看作是一个具有相同功能的脚本模型,但它还保留了原始模型的动态特性,可以使用更多高级特性,如动态图和控制流。
示例:
import torch # 定义模型 class MyModel(torch.nn.Module): def __init__(self): super(MyModel, self).__init__() self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1) self.fc = torch.nn.Linear(64 * 8 * 8, 10) def forward(self, x): x = self.conv(x) x = torch.nn.functional.relu(x) x = x.view(-1, 64 * 8 * 8) x = self.fc(x) return x model = MyModel() # 将模型转换为Torch脚本模块 traced_model = torch.jit.trace(model, torch.randn(1, 3, 32, 32)) # 调用 output = traced_model(torch.randn(1, 3, 32, 32)) print(output) # 保存模型 torch.jit.save(traced_model, './model/Test/traced_model.pth')
注意
由于
torch.jit.trace
方法只跟踪了给定输入张量的执行路径,因此在使用转换后的模块对象进行推理时,输入张量的维度和数据类型必须与跟踪时使用的相同。
torch.jit.save
使用 torch.jit.script
或 torch.jit.trace
转换后的模块对象可以直接用于推理,也可以使用 torch.jit.save
方法将其保存到文件中,以便在需要时加载模型。
torch.jit.load
使用 torch.jit.load
函数可以加载 PyTorch 模型,该函数可以接收一个模型文件路径或一个文件对象作为输入参数。具体步骤如下:
- 加载模型文件:
import torch model = torch.jit.load("model.pt")
这将加载名为
model.pt
的模型文件。
- 加载模型文件并指定设备:
import torch device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = torch.jit.load("model.pt", map_location=device)
这将加载名为
model.pt
的模型文件,并将其放置在可用的 CUDA 设备上。
- 加载模型文件并使用
eval
模式:
import torch model = torch.jit.load("model.pt") model.eval()
这将加载名为
model.pt
的模型文件,并将其转换为评估模式。
注意:
如果模型使用了特定的设备,例如 CUDA,那么在加载模型时需要确保该设备可用。如果设备不可用,则需要使用 map_location
参数将模型映射到可用的设备上。
Code
import torch # 定义模型 class MyModel(torch.nn.Module): def __init__(self): super(MyModel, self).__init__() self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1) self.fc = torch.nn.Linear(64 * 8 * 8, 10) def forward(self, x): x = self.conv(x) x = torch.nn.functional.relu(x) x = x.view(-1, 64 * 8 * 8) x = self.fc(x) return x model = MyModel() print(model) # 将模型转换为Torch脚本模块 scripted_model = torch.jit.script(model) traced_model = torch.jit.trace(model, torch.randn(1, 3, 32, 32)) # 调用 output_scripted = scripted_model(torch.randn(1, 3, 32, 32)) output_traced = traced_model(torch.randn(1, 3, 32, 32)) # 保存模型 torch.jit.save(scripted_model, './model/Test/scripted_model.pth') torch.jit.save(traced_model, './model/Test/traced_model.pth') # 加载模型 load_scripted_model = torch.jit.load('./model/Test/scripted_model.pth') print(load_scripted_model) load_traced_model = torch.jit.load('./model/Test/traced_model.pth') print(load_traced_model)
MyModel( (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (fc): Linear(in_features=4096, out_features=10, bias=True) ) RecursiveScriptModule( original_name=MyModel (conv): RecursiveScriptModule(original_name=Conv2d) (fc): RecursiveScriptModule(original_name=Linear) ) RecursiveScriptModule( original_name=MyModel (conv): RecursiveScriptModule(original_name=Conv2d) (fc): RecursiveScriptModule(original_name=Linear) )
说明:
RecursiveScriptModule
表示一个递归的 TorchScript 模块,类似于一个树形结构。
- 该模块的原始名称为
MyModel
,表示这是一个模型的容器。
- 该容器包含了两个子模块
conv
和fc
,分别是Conv2d
和Linear
的递归脚本模块。意味着这两个子模块也是 TorchScript 模块,并可以在 TorchScript 中进行运算。 RecursiveScriptModule
可以通过torch.jit.script
或torch.jit.trace
将 PyTorch 模型转换为 TorchScript 模块。在转换过程中,每个子模块也会被转换为相应的 TorchScript 模块,并嵌套在父模块中。- 这种嵌套结构可以很好地表示深度学习模型的层次结构。
RecursiveScriptModule
中的模块名称和原始名称可以通过original_name
属性进行访问。
- 例如,
MyModel
的原始名称是MyModel
,conv
模块的原始名称是Conv2d
,fc
模块的原始名称是Linear
。