PyTorch 与 ONNX:模型的跨平台部署策略

本文涉及的产品
实时数仓Hologres,5000CU*H 100GB 3个月
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时计算 Flink 版,5000CU*H 3个月
简介: 【8月更文第27天】深度学习模型的训练通常是在具有强大计算能力的平台上完成的,比如配备有高性能 GPU 的服务器。然而,为了将这些模型应用到实际产品中,往往需要将其部署到各种不同的设备上,包括移动设备、边缘计算设备甚至是嵌入式系统。这就需要一种能够在多种平台上运行的模型格式。ONNX(Open Neural Network Exchange)作为一种开放的标准,旨在解决模型的可移植性问题,使得开发者可以在不同的框架之间无缝迁移模型。本文将介绍如何使用 PyTorch 将训练好的模型导出为 ONNX 格式,并进一步探讨如何在不同平台上部署这些模型。

概述

深度学习模型的训练通常是在具有强大计算能力的平台上完成的,比如配备有高性能 GPU 的服务器。然而,为了将这些模型应用到实际产品中,往往需要将其部署到各种不同的设备上,包括移动设备、边缘计算设备甚至是嵌入式系统。这就需要一种能够在多种平台上运行的模型格式。ONNX(Open Neural Network Exchange)作为一种开放的标准,旨在解决模型的可移植性问题,使得开发者可以在不同的框架之间无缝迁移模型。本文将介绍如何使用 PyTorch 将训练好的模型导出为 ONNX 格式,并进一步探讨如何在不同平台上部署这些模型。

PyTorch 与 ONNX

PyTorch 是一个非常流行的深度学习框架,它支持动态计算图,非常适合快速原型开发和研究实验。然而,当模型需要部署到生产环境时,就需要考虑模型的兼容性和性能问题。ONNX 提供了一种标准的方式来表示模型,使得模型可以在多种框架和硬件平台上运行。

导出 PyTorch 模型为 ONNX

要将 PyTorch 模型导出为 ONNX 格式,你需要安装 PyTorch 和 onnx 库。接下来是一个简单的示例,展示如何将一个简单的卷积神经网络(CNN)导出为 ONNX 格式。

import torch
import torchvision.models as models
import onnx

# 定义模型
model = models.resnet18(pretrained=True)

# 设置模型为评估模式
model.eval()

# 创建一个示例输入张量
x = torch.randn(1, 3, 224, 224, requires_grad=True)

# 导出模型
torch.onnx.export(model,               # 模型
                  x,                   # 示例输入
                  "resnet18.onnx",     # 输出文件名
                  export_params=True,  # 存储训练过的参数
                  opset_version=10,    # ONNX 版本
                  do_constant_folding=True,  # 是否执行常量折叠优化
                  input_names=['input'],    # 输入名字
                  output_names=['output'],  # 输出名字
                  dynamic_axes={
   'input' : {
   0 : 'batch_size'},    # 动态轴
                                'output' : {
   0 : 'batch_size'}})

# 加载导出的 ONNX 模型
onnx_model = onnx.load("resnet18.onnx")

# 检查模型是否正确
onnx.checker.check_model(onnx_model)
print("ONNX model is valid.")

ONNX 运行时

ONNX Runtime 是一个高性能的推理引擎,它可以用来在多种平台上运行 ONNX 格式的模型。以下是一个使用 ONNX Runtime 进行推理的示例。

import numpy as np
import onnxruntime

# 加载 ONNX 模型
ort_session = onnxruntime.InferenceSession("resnet18.onnx")

# 计算 ONNX Runtime 的输出预测
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# 输入数据
inputs = {
   "input": to_numpy(x)}

# 计算输出
ort_inputs = {
   ort_session.get_inputs()[0].name: inputs["input"]}
ort_outs = ort_session.run(None, ort_inputs)

# 输出结果
print("ONNX Runtime output:", ort_outs)

跨平台部署

一旦模型被转换为 ONNX 格式,就可以在不同的平台上部署。例如,你可以在 Android 或 iOS 设备上使用 ONNX Runtime for Mobile,或者在嵌入式设备上使用 ONNX Runtime for Edge。

示例:ONNX Runtime for Mobile

如果你的目标平台是移动设备,可以使用 ONNX Runtime for Mobile。下面是一个简单的示例,展示如何在 Android 上部署 ONNX 模型。

  1. 准备 ONNX 模型

    • 将 ONNX 模型文件添加到 Android 项目的 assets 文件夹中。
  2. 编写 Java 代码

    import org.pytorch.IValue;
    import org.pytorch.Module;
    import org.pytorch.Tensor;
    import org.pytorch.torchvision.TensorImageUtils;
    
    public class ModelInference {
         
        private Module module;
    
        public ModelInference(String modelPath) throws Exception {
         
            // 加载模型
            module = Module.load(modelPath);
        }
    
        public float[] infer(float[] input) {
         
            Tensor tensorInput = Tensor.fromBlob(input, new long[]{
         1, 3, 224, 224});
            IValue output = module.forward(IValue.from(tensorInput)).toIValue();
            float[] outputData = output.toTensor().toFloatArray();
            return outputData;
        }
    }
    
  3. 使用模型

    public class MainActivity extends AppCompatActivity {
         
        private ModelInference model;
    
        @Override
        protected void onCreate(Bundle savedInstanceState) {
         
            super.onCreate(savedInstanceState);
            setContentView(R.layout.activity_main);
    
            try {
         
                model = new ModelInference(getAssets().openFd("resnet18.onnx").getName());
            } catch (Exception e) {
         
                e.printStackTrace();
            }
    
            // 准备输入数据
            float[] input = TensorImageUtils.bitmapToFloatArray(
                    BitmapFactory.decodeResource(getResources(), R.drawable.input_image), false, false);
    
            // 进行推理
            float[] result = model.infer(input);
            Log.d("Inference", Arrays.toString(result));
        }
    }
    

结论

通过使用 PyTorch 与 ONNX,你可以轻松地将训练好的模型部署到各种不同的平台上。这种方式不仅可以提高模型的可移植性,还可以充分利用不同平台上的硬件加速功能,从而提高性能。无论是移动设备、嵌入式系统还是云端服务器,ONNX 都能够帮助你实现高效的模型部署。

目录
相关文章
|
10天前
|
机器学习/深度学习 存储 PyTorch
PyTorch内存优化的10种策略总结:在有限资源环境下高效训练模型
在大规模深度学习模型训练中,GPU内存容量常成为瓶颈,特别是在训练大型语言模型和视觉Transformer时。本文系统介绍了多种内存优化策略,包括混合精度训练、低精度训练(如BF16)、梯度检查点、梯度累积、张量分片与分布式训练、
49 14
PyTorch内存优化的10种策略总结:在有限资源环境下高效训练模型
|
2月前
|
机器学习/深度学习 搜索推荐 PyTorch
基于昇腾用PyTorch实现传统CTR模型WideDeep网络
本文介绍了如何在昇腾平台上使用PyTorch实现经典的WideDeep网络模型,以处理推荐系统中的点击率(CTR)预测问题。
220 66
|
15天前
|
机器学习/深度学习 算法 安全
用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解
本文详细介绍了DeepSeek R1模型的构建过程,涵盖从基础模型选型到多阶段训练流程,再到关键技术如强化学习、拒绝采样和知识蒸馏的应用。
145 3
用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解
|
27天前
|
机器学习/深度学习 存储 算法
近端策略优化(PPO)算法的理论基础与PyTorch代码详解
近端策略优化(PPO)是深度强化学习中高效的策略优化方法,广泛应用于大语言模型的RLHF训练。PPO通过引入策略更新约束机制,平衡了更新幅度,提升了训练稳定性。其核心思想是在优势演员-评论家方法的基础上,采用裁剪和非裁剪项组成的替代目标函数,限制策略比率在[1-ϵ, 1+ϵ]区间内,防止过大的策略更新。本文详细探讨了PPO的基本原理、损失函数设计及PyTorch实现流程,提供了完整的代码示例。
235 10
近端策略优化(PPO)算法的理论基础与PyTorch代码详解
|
5月前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
710 2
|
3月前
|
机器学习/深度学习 人工智能 PyTorch
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
本文探讨了Transformer模型中变长输入序列的优化策略,旨在解决深度学习中常见的计算效率问题。文章首先介绍了批处理变长输入的技术挑战,特别是填充方法导致的资源浪费。随后,提出了多种优化技术,包括动态填充、PyTorch NestedTensors、FlashAttention2和XFormers的memory_efficient_attention。这些技术通过减少冗余计算、优化内存管理和改进计算模式,显著提升了模型的性能。实验结果显示,使用FlashAttention2和无填充策略的组合可以将步骤时间减少至323毫秒,相比未优化版本提升了约2.5倍。
113 3
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
|
7月前
|
机器学习/深度学习 并行计算 PyTorch
优化技巧与策略:提高 PyTorch 模型训练效率
【8月更文第29天】在深度学习领域中,PyTorch 是一个非常流行的框架,被广泛应用于各种机器学习任务中。然而,随着模型复杂度的增加以及数据集规模的增长,如何有效地训练这些模型成为了一个重要的问题。本文将介绍一系列优化技巧和策略,帮助提高 PyTorch 模型训练的效率。
647 0
|
4月前
|
并行计算 监控 搜索推荐
使用 PyTorch-BigGraph 构建和部署大规模图嵌入的完整教程
当处理大规模图数据时,复杂性难以避免。PyTorch-BigGraph (PBG) 是一款专为此设计的工具,能够高效处理数十亿节点和边的图数据。PBG通过多GPU或节点无缝扩展,利用高效的分区技术,生成准确的嵌入表示,适用于社交网络、推荐系统和知识图谱等领域。本文详细介绍PBG的设置、训练和优化方法,涵盖环境配置、数据准备、模型训练、性能优化和实际应用案例,帮助读者高效处理大规模图数据。
91 5
|
5月前
|
机器学习/深度学习 自然语言处理 监控
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
137 7
利用 PyTorch Lightning 搭建一个文本分类模型
|
5月前
|
PyTorch TensorFlow 算法框架/工具
Jetson环境安装(一):Ubuntu18.04安装pytorch、opencv、onnx、tensorflow、setuptools、pycuda....
本文提供了在Ubuntu 18.04操作系统的NVIDIA Jetson平台上安装深度学习和计算机视觉相关库的详细步骤,包括PyTorch、OpenCV、ONNX、TensorFlow等。
326 1
Jetson环境安装(一):Ubuntu18.04安装pytorch、opencv、onnx、tensorflow、setuptools、pycuda....