【tensorflow】TF1.x保存与读取.pb模型写法介绍

简介: 由于TF里面的概念比较接地气,所以用tf1.x保存.pb模型时总是怕有什么操作漏掉了,会造成保存的模型是缺少变量数据或者没有保存图,所以先明确一下:用TF1.x保存模型时只需要保存模型的输入输出的变量(多输入就保存多个),不需要保存中间的变量;用TF1.x加载模型时只需要加载保存的模型,然后读一下输入输出变量(多输入就读多个),不需要初始化(反而会重置掉变量的值)。

由于TF里面的概念比较接地气,所以用tf1.x保存.pb模型时总是怕有什么操作漏掉了,会造成保存的模型是缺少变量数据或者没有保存图,所以先明确一下:用TF1.x保存模型时只需要保存模型的输入输出的变量(多输入就保存多个),不需要保存中间的变量;用TF1.x加载模型时只需要加载保存的模型,然后读一下输入输出变量(多输入就读多个),不需要初始化(反而会重置掉变量的值)。


举例:模型定义如下


# 定义模型
with tf.name_scope("Model"):
    """MLP"""
    # 13个连续特征数据(13列)
    x = tf.placeholder(tf.float32, [None,13], name='X') 
    # 正则化
    x_norm = tf.layers.batch_normalization(inputs=x)
    # 定义一层Dense
    dense_1 = tf.layers.Dense(64, activation="relu")(x_norm)
    """EMBED"""
    # 离散输入
    y = tf.placeholder(tf.int32, [None,2], name='Y')
    # 创建嵌入矩阵变量
    embedding_matrix = tf.Variable(tf.random_uniform([len(vocab_dict) + 1, 8], -1.0, 1.0))
    # 使用tf.nn.embedding_lookup函数获取嵌入向量
    embeddings = tf.nn.embedding_lookup(embedding_matrix, y)
    # 创建 LSTM 层
    lstm_cell = tf.nn.rnn_cell.LSTMCell(64)
    # 初始化 LSTM 单元状态
    initial_state = lstm_cell.zero_state(tf.shape(embeddings)[0], tf.float32)
    # 将输入数据传递给 LSTM 层
    lstm_out, _ = tf.nn.dynamic_rnn(lstm_cell, embeddings, initial_state=initial_state)
    # 定义一层Dense
    dense_2 = tf.layers.Dense(64, activation="relu")(lstm_out[:, -1, :])
    """MERGE"""
    combined = tf.concat([dense_1, dense_2], axis = -1)
    pred = tf.layers.Dense(2, activation="relu")(combined)
    pred = tf.layers.Dense(1, activation="linear", name='P')(pred)
    z = tf.placeholder(tf.float32, [None, 1], name='Z')


  虽然写这么多,但是上面模型的输入只有xyz,输出只有pred。所以我们保存、加载模型时,只用考虑这几个变量就可以。


模型保存代码


import tensorflow as tf
from tensorflow import saved_model as sm
# 创建 Saver 对象
saver = tf.train.Saver()
# 生成会话,训练STEPS轮
with tf.Session() as sess:
    # 初始化参数
    sess.run(tf.global_variables_initializer())
    ...... # 模型训练逻辑
    # 准备存储模型
    path = 'pb_model/'
    dense_model_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    pb_saver = tf.train.Saver(dense_model_var)
    builder = sm.builder.SavedModelBuilder(path)
    # 构建需要在新会话中恢复的变量的 TensorInfo protobuf
    # 自定义 根据自己的模型来写
    X = sm.utils.build_tensor_info(x)
    Y = sm.utils.build_tensor_info(y)
    Z = sm.utils.build_tensor_info(z)
    P = sm.utils.build_tensor_info(pred)
    # 构建 SignatureDef protobuf
    # inputs outputs 自定义 根据自己的模型来写
    SignatureDef = sm.signature_def_utils.build_signature_def(
                                inputs={'X': X, 'Y': Y, 'Z': Z},  # 可用sm.signature_constants.PREDICT_INPUTS
                                outputs={'P': P},  # 可用sm.signature_constants.PREDICT_OUTPUTS
                                method_name="tensorflow/serving/predict"
    )
    # 将 graph 和变量等信息写入 MetaGraphDef protobuf
    # 这里的 tags 里面的参数和 signature_def_map 字典里面的键都可以是自定义字符串,也可用tf里预设好的方便统一
    builder.add_meta_graph_and_variables(sess, tags=['serve'],
                                             signature_def_map={
                                                 sm.signature_constants.PREDICT_METHOD_NAME: SignatureDef},
                                             saver=pb_saver,
                                             main_op=tf.local_variables_initializer())
    # 将 MetaGraphDef 写入磁盘
    builder.save()


 最重要的是这一句:dense_model_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES),意思是保存当前作用域下的所有可训练的变量。


 我之前写的是dense_model_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, name_scope="Model"),这样读不了所有的可训练变量,只能读到embedding_matrix 一个,虽然也能保存模型,但是没保存模型的其他变量值,就会出错。


模型加载代码


import tensorflow as tf
from tensorflow import saved_model as sm
tf.reset_default_graph()
# 创建一个新的默认图
graph = tf.Graph()
# 需要建立一个会话对象,将模型恢复到其中
with tf.Session(graph=graph) as sess:
    path = 'pb_model/'
    MetaGraphDef = sm.loader.load(sess, tags=['serve'], export_dir=path)
    # 解析得到 SignatureDef protobuf
    SignatureDef_map = MetaGraphDef.signature_def
    SignatureDef = SignatureDef_map[sm.signature_constants.PREDICT_METHOD_NAME]
    # 解析得到 3 个变量对应的 TensorInfo protobuf
    X = SignatureDef.inputs['X']
    Y = SignatureDef.inputs['Y']
    Z = SignatureDef.inputs['Z']
    P = SignatureDef.outputs['P']
    # 解析得到具体 Tensor
    # .get_tensor_from_tensor_info() 函数中可以不传入 graph 参数,TensorFlow 自动使用默认图
    # x = sm.utils.get_tensor_from_tensor_info(X)
    # y = sm.utils.get_tensor_from_tensor_info(Y)
    # z = sm.utils.get_tensor_from_tensor_info(Z)
    x = sess.graph.get_tensor_by_name(X.name)
    y = sess.graph.get_tensor_by_name(Y.name)
    z = sess.graph.get_tensor_by_name(Z.name)
    p = sess.graph.get_tensor_by_name(P.name)
    # 这里就可以开始进行预测或者继续训练了 TODO
    total_loss = sess.run(loss_function, feed_dict={x: dense_ch_val, y: sparse_ch_val, z: score_val})
    print(total_loss)



相关文章
|
1月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
将Keras训练好的.hdf5模型转换为TensorFlow的.pb模型,然后再转换为TensorRT支持的.uff格式,并提供了转换代码和测试步骤。
83 3
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
|
5天前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
21 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
5天前
|
机器学习/深度学习 人工智能 算法
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
蔬菜识别系统,本系统使用Python作为主要编程语言,通过收集了8种常见的蔬菜图像数据集('土豆', '大白菜', '大葱', '莲藕', '菠菜', '西红柿', '韭菜', '黄瓜'),然后基于TensorFlow搭建卷积神经网络算法模型,通过多轮迭代训练最后得到一个识别精度较高的模型文件。在使用Django开发web网页端操作界面,实现用户上传一张蔬菜图片识别其名称。
25 0
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
|
22天前
|
机器学习/深度学习 人工智能 算法
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
车辆车型识别,使用Python作为主要编程语言,通过收集多种车辆车型图像数据集,然后基于TensorFlow搭建卷积网络算法模型,并对数据集进行训练,最后得到一个识别精度较高的模型文件。再基于Django搭建web网页端操作界面,实现用户上传一张车辆图片识别其类型。
65 0
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
|
2月前
|
机器学习/深度学习 人工智能 算法
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
鸟类识别系统。本系统采用Python作为主要开发语言,通过使用加利福利亚大学开源的200种鸟类图像作为数据集。使用TensorFlow搭建ResNet50卷积神经网络算法模型,然后进行模型的迭代训练,得到一个识别精度较高的模型,然后在保存为本地的H5格式文件。在使用Django开发Web网页端操作界面,实现用户上传一张鸟类图像,识别其名称。
108 12
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
|
1月前
|
机器学习/深度学习 移动开发 TensorFlow
深度学习之格式转换笔记(四):Keras(.h5)模型转化为TensorFlow(.pb)模型
本文介绍了如何使用Python脚本将Keras模型转换为TensorFlow的.pb格式模型,包括加载模型、重命名输出节点和量化等步骤,以便在TensorFlow中进行部署和推理。
76 0
|
3月前
|
机器学习/深度学习 存储 前端开发
实战揭秘:如何借助TensorFlow.js的强大力量,轻松将高效能的机器学习模型无缝集成到Web浏览器中,从而打造智能化的前端应用并优化用户体验
【8月更文挑战第31天】将机器学习模型集成到Web应用中,可让用户在浏览器内体验智能化功能。TensorFlow.js作为在客户端浏览器中运行的库,提供了强大支持。本文通过问答形式详细介绍如何使用TensorFlow.js将机器学习模型带入Web浏览器,并通过具体示例代码展示最佳实践。首先,需在HTML文件中引入TensorFlow.js库;接着,可通过加载预训练模型如MobileNet实现图像分类;然后,编写代码处理图像识别并显示结果;此外,还介绍了如何训练自定义模型及优化模型性能的方法,包括模型量化、剪枝和压缩等。
51 1
|
3月前
|
机器学习/深度学习 人工智能 PyTorch
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
77 1
|
3月前
|
API UED 开发者
如何在Uno Platform中轻松实现流畅动画效果——从基础到优化,全方位打造用户友好的动态交互体验!
【8月更文挑战第31天】在开发跨平台应用时,确保用户界面流畅且具吸引力至关重要。Uno Platform 作为多端统一的开发框架,不仅支持跨系统应用开发,还能通过优化实现流畅动画,增强用户体验。本文探讨了Uno Platform中实现流畅动画的多个方面,包括动画基础、性能优化、实践技巧及问题排查,帮助开发者掌握具体优化策略,提升应用质量与用户满意度。通过合理利用故事板、减少布局复杂性、使用硬件加速等技术,结合异步方法与预设缓存技巧,开发者能够创建美观且流畅的动画效果。
80 0
|
3月前
|
C# 开发者 前端开发
揭秘混合开发新趋势:Uno Platform携手Blazor,教你一步到位实现跨平台应用,代码复用不再是梦!
【8月更文挑战第31天】随着前端技术的发展,混合开发日益受到开发者青睐。本文详述了如何结合.NET生态下的两大框架——Uno Platform与Blazor,进行高效混合开发。Uno Platform基于WebAssembly和WebGL技术,支持跨平台应用构建;Blazor则让C#成为可能的前端开发语言,实现了客户端与服务器端逻辑共享。二者结合不仅提升了代码复用率与跨平台能力,还简化了项目维护并增强了Web应用性能。文中提供了从环境搭建到示例代码的具体步骤,并展示了如何创建一个简单的计数器应用,帮助读者快速上手混合开发。
84 0