概述
深度学习模型的训练通常是在具有强大计算能力的平台上完成的,比如配备有高性能 GPU 的服务器。然而,为了将这些模型应用到实际产品中,往往需要将其部署到各种不同的设备上,包括移动设备、边缘计算设备甚至是嵌入式系统。这就需要一种能够在多种平台上运行的模型格式。ONNX(Open Neural Network Exchange)作为一种开放的标准,旨在解决模型的可移植性问题,使得开发者可以在不同的框架之间无缝迁移模型。本文将介绍如何使用 PyTorch 将训练好的模型导出为 ONNX 格式,并进一步探讨如何在不同平台上部署这些模型。
PyTorch 与 ONNX
PyTorch 是一个非常流行的深度学习框架,它支持动态计算图,非常适合快速原型开发和研究实验。然而,当模型需要部署到生产环境时,就需要考虑模型的兼容性和性能问题。ONNX 提供了一种标准的方式来表示模型,使得模型可以在多种框架和硬件平台上运行。
导出 PyTorch 模型为 ONNX
要将 PyTorch 模型导出为 ONNX 格式,你需要安装 PyTorch 和 onnx 库。接下来是一个简单的示例,展示如何将一个简单的卷积神经网络(CNN)导出为 ONNX 格式。
import torch
import torchvision.models as models
import onnx
# 定义模型
model = models.resnet18(pretrained=True)
# 设置模型为评估模式
model.eval()
# 创建一个示例输入张量
x = torch.randn(1, 3, 224, 224, requires_grad=True)
# 导出模型
torch.onnx.export(model, # 模型
x, # 示例输入
"resnet18.onnx", # 输出文件名
export_params=True, # 存储训练过的参数
opset_version=10, # ONNX 版本
do_constant_folding=True, # 是否执行常量折叠优化
input_names=['input'], # 输入名字
output_names=['output'], # 输出名字
dynamic_axes={
'input' : {
0 : 'batch_size'}, # 动态轴
'output' : {
0 : 'batch_size'}})
# 加载导出的 ONNX 模型
onnx_model = onnx.load("resnet18.onnx")
# 检查模型是否正确
onnx.checker.check_model(onnx_model)
print("ONNX model is valid.")
ONNX 运行时
ONNX Runtime 是一个高性能的推理引擎,它可以用来在多种平台上运行 ONNX 格式的模型。以下是一个使用 ONNX Runtime 进行推理的示例。
import numpy as np
import onnxruntime
# 加载 ONNX 模型
ort_session = onnxruntime.InferenceSession("resnet18.onnx")
# 计算 ONNX Runtime 的输出预测
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
# 输入数据
inputs = {
"input": to_numpy(x)}
# 计算输出
ort_inputs = {
ort_session.get_inputs()[0].name: inputs["input"]}
ort_outs = ort_session.run(None, ort_inputs)
# 输出结果
print("ONNX Runtime output:", ort_outs)
跨平台部署
一旦模型被转换为 ONNX 格式,就可以在不同的平台上部署。例如,你可以在 Android 或 iOS 设备上使用 ONNX Runtime for Mobile,或者在嵌入式设备上使用 ONNX Runtime for Edge。
示例:ONNX Runtime for Mobile
如果你的目标平台是移动设备,可以使用 ONNX Runtime for Mobile。下面是一个简单的示例,展示如何在 Android 上部署 ONNX 模型。
准备 ONNX 模型:
- 将 ONNX 模型文件添加到 Android 项目的 assets 文件夹中。
编写 Java 代码:
import org.pytorch.IValue; import org.pytorch.Module; import org.pytorch.Tensor; import org.pytorch.torchvision.TensorImageUtils; public class ModelInference { private Module module; public ModelInference(String modelPath) throws Exception { // 加载模型 module = Module.load(modelPath); } public float[] infer(float[] input) { Tensor tensorInput = Tensor.fromBlob(input, new long[]{ 1, 3, 224, 224}); IValue output = module.forward(IValue.from(tensorInput)).toIValue(); float[] outputData = output.toTensor().toFloatArray(); return outputData; } }
使用模型:
public class MainActivity extends AppCompatActivity { private ModelInference model; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); try { model = new ModelInference(getAssets().openFd("resnet18.onnx").getName()); } catch (Exception e) { e.printStackTrace(); } // 准备输入数据 float[] input = TensorImageUtils.bitmapToFloatArray( BitmapFactory.decodeResource(getResources(), R.drawable.input_image), false, false); // 进行推理 float[] result = model.infer(input); Log.d("Inference", Arrays.toString(result)); } }
结论
通过使用 PyTorch 与 ONNX,你可以轻松地将训练好的模型部署到各种不同的平台上。这种方式不仅可以提高模型的可移植性,还可以充分利用不同平台上的硬件加速功能,从而提高性能。无论是移动设备、嵌入式系统还是云端服务器,ONNX 都能够帮助你实现高效的模型部署。