MxNet与Caffe模型之间转换的桥梁-Onnx

简介: Open Neural Network Exchange (ONNX)为AI模型提供了一种开源的数据模型格式。它定义了一个可扩展的计算图模型,以及内置运算符和标准数据类型的定义。它可以作为各种AI模型之间进行转换的媒介,例如,市面上没有现成的Caffe模型到MxNet模型的转换工具,我们可以借助于ONNX,首先将Caffe转换为Onnx,然后再将Onnx转换为MxNet,更为神奇的是,这之间的转换过程不过丢失原有模型的精度。

MxNet模型导出ONNX模型


Open Neural Network Exchange (ONNX)为AI模型提供了一种开源的数据模型格式。它定义了一个可扩展的计算图模型,以及内置运算符和标准数据类型的定义。它可以作为各种AI模型之间进行转换的媒介,例如,市面上没有现成的Caffe模型到MxNet模型的转换工具,我们可以借助于ONNX,首先将Caffe转换为Onnx,然后再将Onnx转换为MxNet,更为神奇的是,这之间的转换过程不过丢失原有模型的精度。


在本教程中,我们将展示如何将MXNet模型保存为ONNX格式。MXNet-ONNX操作符的覆盖范围和特性定期更新。请访问 ONNX operator coverage以获取最新信息。在本教程中,我们将学习如何使用MXNet到ONNX的模型导出工具对预先训练的模型进行导出。


预备知识


要运行本教程,你需要安装以下python模块:


  • MXNet >= 1.3.0注意,经测试使用下面的命令进行安装MXNET可用:PIP INSTALL MXNET==1.4.0 --USER


  • onnx注意,经测试使用如下命令进行安装onnx可用:pip install onnx==1.2.1 --user


**注意:**MXNet-ONNX导入、导出工具遵循ONNX操作符集的第7版,该操作符集附带ONNX v1.2.1。


import mxnet as mx
import numpy as np
from mxnet.contrib import onnx as onnx_mxnet
import logging
logging.basicConfig(level=logging.INFO)


从MXNet Model Zoo下载一个模型


我们从MXNet Model Zoo.下载预训练的ResNet-18 ImageNet 模型。我们还将下载synset文件来匹配标签


# Download pre-trained resnet model - json and params by running following code.
path='http://data.mxnet.io/models/imagenet/'
[mx.test_utils.download(path+'resnet/18-layers/resnet-18-0000.params'),
 mx.test_utils.download(path+'resnet/18-layers/resnet-18-symbol.json'),
 mx.test_utils.download(path+'synset.txt')]


现在,我们已经在磁盘上下载了ResNet-18、params和synset文件。


MXNet到ONNX导出器API


让我们来描述MXNet的' export_model ' API。


help(onnx_mxnet.export_model)
Help on function export_model in module mxnet.contrib.onnx.mx2onnx.export_model:
export_model(sym, params, input_shape, input_type=<type 'numpy.float32'>, onnx_file_path=u'model.onnx', verbose=False)
    Exports the MXNet model file, passed as a parameter, into ONNX model.
    Accepts both symbol,parameter objects as well as json and params filepaths as input.
    Operator support and coverage - https://cwiki.apache.org/confluence/display/MXNET/MXNet-ONNX+Integration
    Parameters
    ----------
    sym : str or symbol object
        Path to the json file or Symbol object
    params : str or symbol object
        Path to the params file or params dictionary. (Including both arg_params and aux_params)
    input_shape : List of tuple
        Input shape of the model e.g [(1,3,224,224)]
    input_type : data type
        Input data type e.g. np.float32
    onnx_file_path : str
        Path where to save the generated onnx file
    verbose : Boolean
        If true will print logs of the model conversion
    Returns
    -------
    onnx_file_path : str
        Onnx file path


' export_model ' API可以通过以下两种方式之一接受MXNet模型。


  1. MXNet sym, params对象:


  • 如果我们正在训练一个模型,这是有用的。在训练结束时,我们只需要调用' export_model '函数,并提供sym和params对象作为输入和其他属性,以将模型保存为ONNX格式。


  1. MXNet导出的json和params文件:


  • 如果我们有预先训练过的模型,并且希望将它们转换为ONNX格式,那么这是非常有用的。


由于我们已经下载了预训练的模型文件,我们将通过传递符号和params文件的路径来使用' export_model ' API。


如何使用MXNet到ONNXA导入、导出工具PI


我们将使用下载的预训练的模型文件(sym、params)并定义输入变量。


# 下载的输入符号和参数文件
sym = './resnet-18-symbol.json'
params = './resnet-18-0000.params'
# 标准Imagenet输入- 3通道,224*224
input_shape = (1,3,224,224)
# 输出文件的路径
onnx_file = './mxnet_exported_resnet50.onnx'


我们已经定义了' export_model ' API所需的输入参数。现在,我们准备将MXNet模型转换为ONNX格式


# 调用导出模型API。它返回转换后的onnx模型的路径
converted_model_path = onnx_mxnet.export_model(sym, params, [input_shape], np.float32, onnx_file)


这个API返回转换后的模型的路径,您稍后可以使用该路径将模型导入其他框架。


检验ONNX模型的有效性


现在我们可以使用ONNX检查工具来检查转换后的ONNX模型的有效性。该工具将通过检查内容是否包含有效的protobuf来验证模型:


from onnx import checker
import onnx
# Load onnx model
model_proto = onnx.load_model(converted_model_path)
# Check if converted ONNX protobuf is valid
checker.check_graph(model_proto.graph)


如果转换后的protobuf格式不符合ONNX proto规范,检查器将抛出错误,但在本例中成功通过。


该方法验证了导出模型原buf的有效性。现在,模型可以导入到其他框架中进行推理了!


相关文章
|
3月前
|
PyTorch 算法框架/工具
Bert PyTorch 源码分析:一、嵌入层
Bert PyTorch 源码分析:一、嵌入层
29 0
|
6天前
|
机器学习/深度学习 TensorFlow 算法框架/工具
TensorFlow中的自定义层与模型
【4月更文挑战第17天】本文介绍了如何在TensorFlow中创建自定义层和模型。自定义层通过继承`tf.keras.layers.Layer`,实现`__init__`, `build`和`call`方法。例如,一个简单的全连接层`CustomDenseLayer`示例展示了如何定义激活函数。自定义模型则继承自`tf.keras.Model`,在`__init__`中定义层,在`call`中实现前向传播。这两个功能使TensorFlow能应对特定需求和复杂网络结构,增强了其在深度学习应用中的灵活性。
|
29天前
|
机器学习/深度学习 人工智能 PyTorch
基于Numpy构建RNN模块并进行实例应用(附代码)
基于Numpy构建RNN模块并进行实例应用(附代码)
33 0
|
4月前
|
机器学习/深度学习 PyTorch TensorFlow
一文带你了解 三种深度学习框架(Caffe,Tensorflow,Pytorch)的基本内容、优缺点以及三者的对比
一文带你了解 三种深度学习框架(Caffe,Tensorflow,Pytorch)的基本内容、优缺点以及三者的对比
149 1
|
人工智能 数据可视化 TensorFlow
从Tensorflow模型文件中解析并显示网络结构图(CKPT模型篇)
从Tensorflow模型文件中解析并显示网络结构图(CKPT模型篇)
从Tensorflow模型文件中解析并显示网络结构图(CKPT模型篇)
|
10月前
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch基础模块一模型(1)
Pytorch基础模块一模型(1)
|
11月前
|
并行计算 PyTorch 测试技术
PyTorch 之 简介、相关软件框架、基本使用方法、tensor 的几种形状和 autograd 机制-2
由于要进行 tensor 的学习,因此,我们先导入我们需要的库。
|
11月前
|
机器学习/深度学习 人工智能 自然语言处理
PyTorch 之 简介、相关软件框架、基本使用方法、tensor 的几种形状和 autograd 机制-1
PyTorch 是一个基于 Torch 的 Python 开源机器学习库,用于自然语言处理等应用程序。它主要由 Facebook 的人工智能小组开发,不仅能够实现强大的 GPU 加速,同时还支持动态神经网络,这一点是现在很多主流框架如 TensorFlow 都不支持的。
|
分布式计算 并行计算 Hadoop
Tensorflow目标检测接口配合tflite量化模型(二)
Tensorflow目标检测接口配合tflite量化模型
312 0
Tensorflow目标检测接口配合tflite量化模型(二)
|
XML 存储 TensorFlow
Tensorflow目标检测接口配合tflite量化模型(一)
Tensorflow目标检测接口配合tflite量化模型
152 0
Tensorflow目标检测接口配合tflite量化模型(一)

热门文章

最新文章