【tensorflow】TF1.x保存与读取.pb模型写法介绍

简介: 由于TF里面的概念比较接地气,所以用tf1.x保存.pb模型时总是怕有什么操作漏掉了,会造成保存的模型是缺少变量数据或者没有保存图,所以先明确一下:用TF1.x保存模型时只需要保存模型的输入输出的变量(多输入就保存多个),不需要保存中间的变量;用TF1.x加载模型时只需要加载保存的模型,然后读一下输入输出变量(多输入就读多个),不需要初始化(反而会重置掉变量的值)。

由于TF里面的概念比较接地气,所以用tf1.x保存.pb模型时总是怕有什么操作漏掉了,会造成保存的模型是缺少变量数据或者没有保存图,所以先明确一下:用TF1.x保存模型时只需要保存模型的输入输出的变量(多输入就保存多个),不需要保存中间的变量;用TF1.x加载模型时只需要加载保存的模型,然后读一下输入输出变量(多输入就读多个),不需要初始化(反而会重置掉变量的值)。


举例:模型定义如下


# 定义模型
with tf.name_scope("Model"):
    """MLP"""
    # 13个连续特征数据(13列)
    x = tf.placeholder(tf.float32, [None,13], name='X') 
    # 正则化
    x_norm = tf.layers.batch_normalization(inputs=x)
    # 定义一层Dense
    dense_1 = tf.layers.Dense(64, activation="relu")(x_norm)
    """EMBED"""
    # 离散输入
    y = tf.placeholder(tf.int32, [None,2], name='Y')
    # 创建嵌入矩阵变量
    embedding_matrix = tf.Variable(tf.random_uniform([len(vocab_dict) + 1, 8], -1.0, 1.0))
    # 使用tf.nn.embedding_lookup函数获取嵌入向量
    embeddings = tf.nn.embedding_lookup(embedding_matrix, y)
    # 创建 LSTM 层
    lstm_cell = tf.nn.rnn_cell.LSTMCell(64)
    # 初始化 LSTM 单元状态
    initial_state = lstm_cell.zero_state(tf.shape(embeddings)[0], tf.float32)
    # 将输入数据传递给 LSTM 层
    lstm_out, _ = tf.nn.dynamic_rnn(lstm_cell, embeddings, initial_state=initial_state)
    # 定义一层Dense
    dense_2 = tf.layers.Dense(64, activation="relu")(lstm_out[:, -1, :])
    """MERGE"""
    combined = tf.concat([dense_1, dense_2], axis = -1)
    pred = tf.layers.Dense(2, activation="relu")(combined)
    pred = tf.layers.Dense(1, activation="linear", name='P')(pred)
    z = tf.placeholder(tf.float32, [None, 1], name='Z')


  虽然写这么多,但是上面模型的输入只有xyz,输出只有pred。所以我们保存、加载模型时,只用考虑这几个变量就可以。


模型保存代码


import tensorflow as tf
from tensorflow import saved_model as sm
# 创建 Saver 对象
saver = tf.train.Saver()
# 生成会话,训练STEPS轮
with tf.Session() as sess:
    # 初始化参数
    sess.run(tf.global_variables_initializer())
    ...... # 模型训练逻辑
    # 准备存储模型
    path = 'pb_model/'
    dense_model_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    pb_saver = tf.train.Saver(dense_model_var)
    builder = sm.builder.SavedModelBuilder(path)
    # 构建需要在新会话中恢复的变量的 TensorInfo protobuf
    # 自定义 根据自己的模型来写
    X = sm.utils.build_tensor_info(x)
    Y = sm.utils.build_tensor_info(y)
    Z = sm.utils.build_tensor_info(z)
    P = sm.utils.build_tensor_info(pred)
    # 构建 SignatureDef protobuf
    # inputs outputs 自定义 根据自己的模型来写
    SignatureDef = sm.signature_def_utils.build_signature_def(
                                inputs={'X': X, 'Y': Y, 'Z': Z},  # 可用sm.signature_constants.PREDICT_INPUTS
                                outputs={'P': P},  # 可用sm.signature_constants.PREDICT_OUTPUTS
                                method_name="tensorflow/serving/predict"
    )
    # 将 graph 和变量等信息写入 MetaGraphDef protobuf
    # 这里的 tags 里面的参数和 signature_def_map 字典里面的键都可以是自定义字符串,也可用tf里预设好的方便统一
    builder.add_meta_graph_and_variables(sess, tags=['serve'],
                                             signature_def_map={
                                                 sm.signature_constants.PREDICT_METHOD_NAME: SignatureDef},
                                             saver=pb_saver,
                                             main_op=tf.local_variables_initializer())
    # 将 MetaGraphDef 写入磁盘
    builder.save()


 最重要的是这一句:dense_model_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES),意思是保存当前作用域下的所有可训练的变量。


 我之前写的是dense_model_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, name_scope="Model"),这样读不了所有的可训练变量,只能读到embedding_matrix 一个,虽然也能保存模型,但是没保存模型的其他变量值,就会出错。


模型加载代码


import tensorflow as tf
from tensorflow import saved_model as sm
tf.reset_default_graph()
# 创建一个新的默认图
graph = tf.Graph()
# 需要建立一个会话对象,将模型恢复到其中
with tf.Session(graph=graph) as sess:
    path = 'pb_model/'
    MetaGraphDef = sm.loader.load(sess, tags=['serve'], export_dir=path)
    # 解析得到 SignatureDef protobuf
    SignatureDef_map = MetaGraphDef.signature_def
    SignatureDef = SignatureDef_map[sm.signature_constants.PREDICT_METHOD_NAME]
    # 解析得到 3 个变量对应的 TensorInfo protobuf
    X = SignatureDef.inputs['X']
    Y = SignatureDef.inputs['Y']
    Z = SignatureDef.inputs['Z']
    P = SignatureDef.outputs['P']
    # 解析得到具体 Tensor
    # .get_tensor_from_tensor_info() 函数中可以不传入 graph 参数,TensorFlow 自动使用默认图
    # x = sm.utils.get_tensor_from_tensor_info(X)
    # y = sm.utils.get_tensor_from_tensor_info(Y)
    # z = sm.utils.get_tensor_from_tensor_info(Z)
    x = sess.graph.get_tensor_by_name(X.name)
    y = sess.graph.get_tensor_by_name(Y.name)
    z = sess.graph.get_tensor_by_name(Z.name)
    p = sess.graph.get_tensor_by_name(P.name)
    # 这里就可以开始进行预测或者继续训练了 TODO
    total_loss = sess.run(loss_function, feed_dict={x: dense_ch_val, y: sparse_ch_val, z: score_val})
    print(total_loss)



相关文章
|
4月前
|
机器学习/深度学习 算法 TensorFlow
文本分类识别Python+卷积神经网络算法+TensorFlow模型训练+Django可视化界面
文本分类识别Python+卷积神经网络算法+TensorFlow模型训练+Django可视化界面
65 0
文本分类识别Python+卷积神经网络算法+TensorFlow模型训练+Django可视化界面
|
4月前
|
机器学习/深度学习 监控 Python
tensorflow2.x多层感知机模型参数量和计算量的统计
tensorflow2.x多层感知机模型参数量和计算量的统计
|
7月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【tensorflow】连续输入的线性回归模型训练代码
  get_data函数用于生成随机的训练和验证数据集。首先使用np.random.rand生成一个形状为(10000, 10)的随机数据集,来模拟10维的连续输入,然后使用StandardScaler对数据进行标准化。再生成一个(10000,1)的target,表示最终拟合的目标分数。最后使用train_test_split函数将数据集划分为训练集和验证集。
|
7月前
|
机器学习/深度学习 算法 TensorFlow
树叶识别系统python+Django网页界面+TensorFlow+算法模型+数据集+图像识别分类
树叶识别系统python+Django网页界面+TensorFlow+算法模型+数据集+图像识别分类
134 1
|
7月前
|
机器学习/深度学习 移动开发 算法
动物识别系统python+Django网页界面+TensorFlow算法模型+数据集训练
动物识别系统python+Django网页界面+TensorFlow算法模型+数据集训练
92 0
动物识别系统python+Django网页界面+TensorFlow算法模型+数据集训练
|
4月前
|
机器学习/深度学习 搜索推荐 算法
推荐系统离线评估方法和评估指标,以及在推荐服务器内部实现A/B测试和解决A/B测试资源紧张的方法。还介绍了如何在TensorFlow中进行模型离线评估实践。
推荐系统离线评估方法和评估指标,以及在推荐服务器内部实现A/B测试和解决A/B测试资源紧张的方法。还介绍了如何在TensorFlow中进行模型离线评估实践。
197 0
|
2天前
|
机器学习/深度学习 算法 TensorFlow
TensorFlow 2keras开发深度学习模型实例:多层感知器(MLP),卷积神经网络(CNN)和递归神经网络(RNN)
TensorFlow 2keras开发深度学习模型实例:多层感知器(MLP),卷积神经网络(CNN)和递归神经网络(RNN)
|
4天前
|
机器学习/深度学习 TensorFlow API
Python安装TensorFlow 2、tf.keras和深度学习模型的定义
Python安装TensorFlow 2、tf.keras和深度学习模型的定义
|
14天前
|
机器学习/深度学习 TensorFlow 调度
优化TensorFlow模型:超参数调整与训练技巧
【4月更文挑战第17天】本文探讨了如何优化TensorFlow模型的性能,重点介绍了超参数调整和训练技巧。超参数如学习率、批量大小和层数对模型性能至关重要。文章提到了三种超参数调整策略:网格搜索、随机搜索和贝叶斯优化。此外,还分享了训练技巧,包括学习率调度、早停、数据增强和正则化,这些都有助于防止过拟合并提高模型泛化能力。结合这些方法,可构建更高效、健壮的深度学习模型。
|
2月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
OpenCV读取tensorflow 2.X模型的方法:将SavedModel转为frozen graph
【2月更文挑战第22天】本文介绍基于Python的tensorflow库,将tensorflow与keras训练好的SavedModel格式神经网络模型转换为frozen graph格式,从而可以用OpenCV库在C++等其他语言中将其打开的方法~
OpenCV读取tensorflow 2.X模型的方法:将SavedModel转为frozen graph

热门文章

最新文章