PyTorch 与 ONNX:模型的跨平台部署策略

本文涉及的产品
实时数仓Hologres,5000CU*H 100GB 3个月
实时计算 Flink 版,1000CU*H 3个月
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
简介: 【8月更文第27天】深度学习模型的训练通常是在具有强大计算能力的平台上完成的,比如配备有高性能 GPU 的服务器。然而,为了将这些模型应用到实际产品中,往往需要将其部署到各种不同的设备上,包括移动设备、边缘计算设备甚至是嵌入式系统。这就需要一种能够在多种平台上运行的模型格式。ONNX(Open Neural Network Exchange)作为一种开放的标准,旨在解决模型的可移植性问题,使得开发者可以在不同的框架之间无缝迁移模型。本文将介绍如何使用 PyTorch 将训练好的模型导出为 ONNX 格式,并进一步探讨如何在不同平台上部署这些模型。

概述

深度学习模型的训练通常是在具有强大计算能力的平台上完成的,比如配备有高性能 GPU 的服务器。然而,为了将这些模型应用到实际产品中,往往需要将其部署到各种不同的设备上,包括移动设备、边缘计算设备甚至是嵌入式系统。这就需要一种能够在多种平台上运行的模型格式。ONNX(Open Neural Network Exchange)作为一种开放的标准,旨在解决模型的可移植性问题,使得开发者可以在不同的框架之间无缝迁移模型。本文将介绍如何使用 PyTorch 将训练好的模型导出为 ONNX 格式,并进一步探讨如何在不同平台上部署这些模型。

PyTorch 与 ONNX

PyTorch 是一个非常流行的深度学习框架,它支持动态计算图,非常适合快速原型开发和研究实验。然而,当模型需要部署到生产环境时,就需要考虑模型的兼容性和性能问题。ONNX 提供了一种标准的方式来表示模型,使得模型可以在多种框架和硬件平台上运行。

导出 PyTorch 模型为 ONNX

要将 PyTorch 模型导出为 ONNX 格式,你需要安装 PyTorch 和 onnx 库。接下来是一个简单的示例,展示如何将一个简单的卷积神经网络(CNN)导出为 ONNX 格式。

import torch
import torchvision.models as models
import onnx

# 定义模型
model = models.resnet18(pretrained=True)

# 设置模型为评估模式
model.eval()

# 创建一个示例输入张量
x = torch.randn(1, 3, 224, 224, requires_grad=True)

# 导出模型
torch.onnx.export(model,               # 模型
                  x,                   # 示例输入
                  "resnet18.onnx",     # 输出文件名
                  export_params=True,  # 存储训练过的参数
                  opset_version=10,    # ONNX 版本
                  do_constant_folding=True,  # 是否执行常量折叠优化
                  input_names=['input'],    # 输入名字
                  output_names=['output'],  # 输出名字
                  dynamic_axes={
   'input' : {
   0 : 'batch_size'},    # 动态轴
                                'output' : {
   0 : 'batch_size'}})

# 加载导出的 ONNX 模型
onnx_model = onnx.load("resnet18.onnx")

# 检查模型是否正确
onnx.checker.check_model(onnx_model)
print("ONNX model is valid.")

ONNX 运行时

ONNX Runtime 是一个高性能的推理引擎,它可以用来在多种平台上运行 ONNX 格式的模型。以下是一个使用 ONNX Runtime 进行推理的示例。

import numpy as np
import onnxruntime

# 加载 ONNX 模型
ort_session = onnxruntime.InferenceSession("resnet18.onnx")

# 计算 ONNX Runtime 的输出预测
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# 输入数据
inputs = {
   "input": to_numpy(x)}

# 计算输出
ort_inputs = {
   ort_session.get_inputs()[0].name: inputs["input"]}
ort_outs = ort_session.run(None, ort_inputs)

# 输出结果
print("ONNX Runtime output:", ort_outs)

跨平台部署

一旦模型被转换为 ONNX 格式,就可以在不同的平台上部署。例如,你可以在 Android 或 iOS 设备上使用 ONNX Runtime for Mobile,或者在嵌入式设备上使用 ONNX Runtime for Edge。

示例:ONNX Runtime for Mobile

如果你的目标平台是移动设备,可以使用 ONNX Runtime for Mobile。下面是一个简单的示例,展示如何在 Android 上部署 ONNX 模型。

  1. 准备 ONNX 模型

    • 将 ONNX 模型文件添加到 Android 项目的 assets 文件夹中。
  2. 编写 Java 代码

    import org.pytorch.IValue;
    import org.pytorch.Module;
    import org.pytorch.Tensor;
    import org.pytorch.torchvision.TensorImageUtils;
    
    public class ModelInference {
         
        private Module module;
    
        public ModelInference(String modelPath) throws Exception {
         
            // 加载模型
            module = Module.load(modelPath);
        }
    
        public float[] infer(float[] input) {
         
            Tensor tensorInput = Tensor.fromBlob(input, new long[]{
         1, 3, 224, 224});
            IValue output = module.forward(IValue.from(tensorInput)).toIValue();
            float[] outputData = output.toTensor().toFloatArray();
            return outputData;
        }
    }
    
  3. 使用模型

    public class MainActivity extends AppCompatActivity {
         
        private ModelInference model;
    
        @Override
        protected void onCreate(Bundle savedInstanceState) {
         
            super.onCreate(savedInstanceState);
            setContentView(R.layout.activity_main);
    
            try {
         
                model = new ModelInference(getAssets().openFd("resnet18.onnx").getName());
            } catch (Exception e) {
         
                e.printStackTrace();
            }
    
            // 准备输入数据
            float[] input = TensorImageUtils.bitmapToFloatArray(
                    BitmapFactory.decodeResource(getResources(), R.drawable.input_image), false, false);
    
            // 进行推理
            float[] result = model.infer(input);
            Log.d("Inference", Arrays.toString(result));
        }
    }
    

结论

通过使用 PyTorch 与 ONNX,你可以轻松地将训练好的模型部署到各种不同的平台上。这种方式不仅可以提高模型的可移植性,还可以充分利用不同平台上的硬件加速功能,从而提高性能。无论是移动设备、嵌入式系统还是云端服务器,ONNX 都能够帮助你实现高效的模型部署。

目录
相关文章
|
2月前
|
机器学习/深度学习 数据采集 人工智能
PyTorch学习实战:AI从数学基础到模型优化全流程精解
本文系统讲解人工智能、机器学习与深度学习的层级关系,涵盖PyTorch环境配置、张量操作、数据预处理、神经网络基础及模型训练全流程,结合数学原理与代码实践,深入浅出地介绍激活函数、反向传播等核心概念,助力快速入门深度学习。
183 1
|
6月前
|
机器学习/深度学习 PyTorch API
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
本文深入探讨神经网络模型量化技术,重点讲解训练后量化(PTQ)与量化感知训练(QAT)两种主流方法。PTQ通过校准数据集确定量化参数,快速实现模型压缩,但精度损失较大;QAT在训练中引入伪量化操作,使模型适应低精度环境,显著提升量化后性能。文章结合PyTorch实现细节,介绍Eager模式、FX图模式及PyTorch 2导出量化等工具,并分享大语言模型Int4/Int8混合精度实践。最后总结量化最佳策略,包括逐通道量化、混合精度设置及目标硬件适配,助力高效部署深度学习模型。
961 21
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
|
2月前
|
机器学习/深度学习 算法 安全
近端策略优化算法PPO的核心概念和PyTorch实现详解
近端策略优化(PPO)是强化学习中的关键算法,因其在复杂任务中的稳定表现而广泛应用。本文详解PPO核心原理,并提供基于PyTorch的完整实现方案,涵盖环境交互、优势计算与策略更新裁剪机制。通过Lunar Lander环境演示训练流程,帮助读者掌握算法精髓。
363 54
|
1月前
|
边缘计算 人工智能 PyTorch
130_知识蒸馏技术:温度参数与损失函数设计 - 教师-学生模型的优化策略与PyTorch实现
随着大型语言模型(LLM)的规模不断增长,部署这些模型面临着巨大的计算和资源挑战。以DeepSeek-R1为例,其671B参数的规模即使经过INT4量化后,仍需要至少6张高端GPU才能运行,这对于大多数中小型企业和研究机构来说成本过高。知识蒸馏作为一种有效的模型压缩技术,通过将大型教师模型的知识迁移到小型学生模型中,在显著降低模型复杂度的同时保留核心性能,成为解决这一问题的关键技术之一。
|
2月前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
146 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
8月前
|
机器学习/深度学习 JavaScript PyTorch
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
生成对抗网络(GAN)的训练效果高度依赖于损失函数的选择。本文介绍了经典GAN损失函数理论,并用PyTorch实现多种变体,包括原始GAN、LS-GAN、WGAN及WGAN-GP等。通过分析其原理与优劣,如LS-GAN提升训练稳定性、WGAN-GP改善图像质量,展示了不同场景下损失函数的设计思路。代码实现覆盖生成器与判别器的核心逻辑,为实际应用提供了重要参考。未来可探索组合优化与自适应设计以提升性能。
677 7
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
|
3月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.0性能优化实战:4种常见代码错误严重拖慢模型
我们将深入探讨图中断(graph breaks)和多图问题对性能的负面影响,并分析PyTorch模型开发中应当避免的常见错误模式。
246 9
|
3月前
|
机器学习/深度学习 算法 数据可视化
近端策略优化算法PPO的核心概念和PyTorch实现详解
本文深入解析了近端策略优化(PPO)算法的核心原理,并基于PyTorch框架实现了完整的强化学习训练流程。通过Lunar Lander环境展示了算法的全过程,涵盖环境交互、优势函数计算、策略更新等关键模块。内容理论与实践结合,适合希望掌握PPO算法及其实现的读者。
594 2
近端策略优化算法PPO的核心概念和PyTorch实现详解
|
5月前
|
机器学习/深度学习 存储 PyTorch
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
本文通过使用 Kaggle 数据集训练情感分析模型的实例,详细演示了如何将 PyTorch 与 MLFlow 进行深度集成,实现完整的实验跟踪、模型记录和结果可复现性管理。文章将系统性地介绍训练代码的核心组件,展示指标和工件的记录方法,并提供 MLFlow UI 的详细界面截图。
252 2
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
|
4月前
|
机器学习/深度学习 数据可视化 PyTorch
Flow Matching生成模型:从理论基础到Pytorch代码实现
本文将系统阐述Flow Matching的完整实现过程,包括数学理论推导、模型架构设计、训练流程构建以及速度场学习等关键组件。通过本文的学习,读者将掌握Flow Matching的核心原理,获得一个完整的PyTorch实现,并对生成模型在噪声调度和分数函数之外的发展方向有更深入的理解。
1817 0
Flow Matching生成模型:从理论基础到Pytorch代码实现

推荐镜像

更多