【tensorflow】TF1.x保存.pb模型 解决模型越训练越大问题

简介: 在上一篇博客【tensorflow】TF1.x保存与读取.pb模型写法介绍介绍的保存.pb模型方法中,保存的是模型训练过程中所有的参数,而且训练越久,最终保存的模型就越大。我的模型只有几千参数,可是最终保存的文件有1GB。。。。

 在上一篇博客【tensorflow】TF1.x保存与读取.pb模型写法介绍介绍的保存.pb模型方法中,保存的是模型训练过程中所有的参数,而且训练越久,最终保存的模型就越大。我的模型只有几千参数,可是最终保存的文件有1GB。。。。


 但是其实我只想要保存参数去部署模型,然后预测。网上有一些解决方案但都不是我需要的,因为我要用Java部署模型,python这里必须要用builder.add_meta_graph_and_variables来保存参数。以下是解决方案:


举例:模型定义如下


# 定义模型
with tf.name_scope("Model"):
    """MLP"""
    # 13个连续特征数据(13列)
    x = tf.placeholder(tf.float32, [None,13], name='X') 
    # 正则化
    x_norm = tf.layers.batch_normalization(inputs=x)
    # 定义一层Dense
    dense_1 = tf.layers.Dense(64, activation="relu")(x_norm)
    """EMBED"""
    # 离散输入
    y = tf.placeholder(tf.int32, [None,2], name='Y')
    # 创建嵌入矩阵变量
    embedding_matrix = tf.Variable(tf.random_uniform([len(vocab_dict) + 1, 8], -1.0, 1.0))
    # 使用tf.nn.embedding_lookup函数获取嵌入向量
    embeddings = tf.nn.embedding_lookup(embedding_matrix, y)
    # 创建 LSTM 层
    lstm_cell = tf.nn.rnn_cell.LSTMCell(64)
    # 初始化 LSTM 单元状态
    initial_state = lstm_cell.zero_state(tf.shape(embeddings)[0], tf.float32)
    # 将输入数据传递给 LSTM 层
    lstm_out, _ = tf.nn.dynamic_rnn(lstm_cell, embeddings, initial_state=initial_state)
    # 定义一层Dense
    dense_2 = tf.layers.Dense(64, activation="relu")(lstm_out[:, -1, :])
    """MERGE"""
    combined = tf.concat([dense_1, dense_2], axis = -1)
    pred = tf.layers.Dense(2, activation="relu")(combined)
    pred = tf.layers.Dense(1, activation="linear", name='P')(pred)
    z = tf.placeholder(tf.float32, [None, 1], name='Z')


  虽然写这么多,但是上面模型的输入只有xyz,输出只有pred。所以我们保存、加载模型时,只用考虑这几个变量就可以。


模型保存代码


  这里的保存方法建议对比上一篇博客【tensorflow】TF1.x保存与读取.pb模型写法介绍介绍的保存.pb模型方法来看。


import tensorflow as tf
from tensorflow import saved_model as sm
from tensorflow.tools.graph_transforms import TransformGraph
from tensorflow.core.framework import graph_pb2
def get_node_names(name_list, nodes_list):
    name_list.extend([n.name.split(":")[0] for _, n in nodes_list.items() if n.name.split(":")[0] != ''])
# 创建 Saver 对象
saver = tf.train.Saver()
# 生成会话,训练STEPS轮
with tf.Session() as sess:
    # 初始化参数
    sess.run(tf.global_variables_initializer())
    ...... # 模型训练逻辑
   # 准备存储模型
    path = 'pb_model/'
    # 创建 Saver 对象,用于保存和加载模型的变量
    pb_saver = tf.train.Saver(var_list=None)
    # 将 Saver 对象转换为 SaverDef 对象
    saver_def = pb_saver.as_saver_def()
    # 从会话的图定义中提取包含恢复操作的子图
    saver_def_ingraph = tf.graph_util.extract_sub_graph(sess.graph.as_graph_def(), [saver_def.restore_op_name])
    # 构建需要在新会话中恢复的变量的 TensorInfo protobuf
    # 自定义 根据自己的模型来写
    inputs = {
        'x' : sm.utils.build_tensor_info(x),
        'y' : sm.utils.build_tensor_info(y),
        'z' : sm.utils.build_tensor_info(z)
    }
    outputs = {
        'p' : sm.utils.build_tensor_info(pred)
    }
    # 获取节点的名称
    input_node_names = []
    get_node_names(input_node_names, inputs)
    output_node_names = []
    get_node_names(output_node_names, outputs)
    # 获取当前会话的图定义
    input_graph_def = sess.graph.as_graph_def()
    # 定义需要应用的图转换操作列表
    transforms = ['add_default_attributes',
                  'fold_constants(ignore_errors=true)',
                  'fold_batch_norms',
                  'fold_old_batch_norms',
                  'sort_by_execution_order',
                  'strip_unused_nodes']
    # 应用图转换操作,并获取优化后的图定义
    opt_graph_def = TransformGraph(input_graph_def,
                                   input_node_names,
                                   output_node_names,
                                   transforms)
    # 创建新的默认图并导入优化后的图定义
    with tf.Graph().as_default() as graph:
        all_names = set([node.name for node in opt_graph_def.node])
        saver_def_ingraph_nodes = [node for node in saver_def_ingraph.node if not node.name in all_names]
        merged_graph_def = graph_pb2.GraphDef()
        merged_graph_def.node.extend(opt_graph_def.node)
        merged_graph_def.node.extend(saver_def_ingraph_nodes)
        # 导入合并后的图定义到新的默认图中
        tf.graph_util.import_graph_def(merged_graph_def, name="")
        builder = sm.builder.SavedModelBuilder(path)
        # 将 graph 和变量等信息写入 MetaGraphDef protobuf
        # 这里的 tags 里面的参数和 signature_def_map 字典里面的键都可以是自定义字符串,也可用tf里预设好的方便统一
        builder.add_meta_graph_and_variables(
            sess, tags=[sm.tag_constants.SERVING],
            signature_def_map={sm.signature_constants.PREDICT_METHOD_NAME: SignatureDef},
            saver=pb_saver,
            main_op=tf.local_variables_initializer()
        )
        # 将 MetaGraphDef 写入磁盘
        builder.save()


  这样之后你会发现模型的大小从GB锐减到几十KB。


相关文章
|
11天前
|
机器学习/深度学习 TensorFlow API
TensorFlow与Keras实战:构建深度学习模型
本文探讨了TensorFlow和其高级API Keras在深度学习中的应用。TensorFlow是Google开发的高性能开源框架,支持分布式计算,而Keras以其用户友好和模块化设计简化了神经网络构建。通过一个手写数字识别的实战案例,展示了如何使用Keras加载MNIST数据集、构建CNN模型、训练及评估模型,并进行预测。案例详述了数据预处理、模型构建、训练过程和预测新图像的步骤,为读者提供TensorFlow和Keras的基础实践指导。
143 59
|
1月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
关于Tensorflow!目标检测预训练模型的迁移学习
这篇文章主要介绍了使用Tensorflow进行目标检测的迁移学习过程。关于使用Tensorflow进行目标检测模型训练的实战教程,涵盖了从数据准备到模型应用的全过程,特别适合对此领域感兴趣的开发者参考。
35 3
关于Tensorflow!目标检测预训练模型的迁移学习
|
19天前
|
机器学习/深度学习 算法 TensorFlow
【图像识别】谷物识别系统Python+人工智能深度学习+TensorFlow+卷积算法网络模型+图像识别
谷物识别系统,本系统使用Python作为主要编程语言,通过TensorFlow搭建ResNet50卷积神经算法网络模型,通过对11种谷物图片数据集('大米', '小米', '燕麦', '玉米渣', '红豆', '绿豆', '花生仁', '荞麦', '黄豆', '黑米', '黑豆')进行训练,得到一个进度较高的H5格式的模型文件。然后使用Django框架搭建了一个Web网页端可视化操作界面。实现用户上传一张图片识别其名称。
55 0
【图像识别】谷物识别系统Python+人工智能深度学习+TensorFlow+卷积算法网络模型+图像识别
|
1月前
|
机器学习/深度学习 人工智能 算法
食物识别系统Python+深度学习人工智能+TensorFlow+卷积神经网络算法模型
食物识别系统采用TensorFlow的ResNet50模型,训练了包含11类食物的数据集,生成高精度H5模型。系统整合Django框架,提供网页平台,用户可上传图片进行食物识别。效果图片展示成功识别各类食物。[查看演示视频、代码及安装指南](https://www.yuque.com/ziwu/yygu3z/yhd6a7vai4o9iuys?singleDoc#)。项目利用深度学习的卷积神经网络(CNN),其局部感受野和权重共享机制适于图像识别,广泛应用于医疗图像分析等领域。示例代码展示了一个使用TensorFlow训练的简单CNN模型,用于MNIST手写数字识别。
59 3
|
1月前
|
机器学习/深度学习 TensorFlow API
Python深度学习基于Tensorflow(3)Tensorflow 构建模型
Python深度学习基于Tensorflow(3)Tensorflow 构建模型
80 2
|
23天前
|
机器学习/深度学习 人工智能 算法
中草药识别系统Python+深度学习人工智能+TensorFlow+卷积神经网络算法模型
中草药识别系统Python+深度学习人工智能+TensorFlow+卷积神经网络算法模型
62 0
|
25天前
|
机器学习/深度学习 自然语言处理 TensorFlow
构建高效的机器学习模型:基于Python和TensorFlow的实践
构建高效的机器学习模型:基于Python和TensorFlow的实践
39 0
|
1月前
|
机器学习/深度学习 数据可视化 TensorFlow
【Python 机器学习专栏】使用 TensorFlow 构建深度学习模型
【4月更文挑战第30天】本文介绍了如何使用 TensorFlow 构建深度学习模型。TensorFlow 是谷歌的开源深度学习框架,具备强大计算能力和灵活编程接口。构建模型涉及数据准备、模型定义、选择损失函数和优化器、训练、评估及模型保存部署。文中以全连接神经网络为例,展示了从数据预处理到模型训练和评估的完整流程。此外,还提到了 TensorFlow 的自动微分、模型可视化和分布式训练等高级特性。通过本文,读者可掌握 TensorFlow 基本用法,为构建高效深度学习模型打下基础。
|
1月前
|
机器学习/深度学习 算法 TensorFlow
TensorFlow 2keras开发深度学习模型实例:多层感知器(MLP),卷积神经网络(CNN)和递归神经网络(RNN)
TensorFlow 2keras开发深度学习模型实例:多层感知器(MLP),卷积神经网络(CNN)和递归神经网络(RNN)
|
1月前
|
机器学习/深度学习 TensorFlow API
Python安装TensorFlow 2、tf.keras和深度学习模型的定义
Python安装TensorFlow 2、tf.keras和深度学习模型的定义