inference.py的代码解释

简介: 这是一个 Python 脚本,它用于导出经过训练的模型,使其可以在生产环境中进行推理。该脚本首先使用 TensorFlow 的 flags 定义了一些参数,如模型版本号、模型路径、输出目录等等。然后,它创建了一个名为 inference_graph 的 TensorFlow 图,并定义了一个 InferenceModel,该模型用于从输入数据中推断评级。
# -*- coding: utf-8 -*-
"""
Created on Sat Jan 19 13:39:52 2019
@author: Artem Oppermann
"""
import tensorflow as tf
import os
from model.inference_model import InferenceModel
tf.app.flags.DEFINE_string('checkpoints_path', os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..', 'checkpoints/')), 
                           'Path for the test data.')
tf.app.flags.DEFINE_string('export_path_base', os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..', 'model-export/')), 
                           'Directory where to export the model.')
tf.app.flags.DEFINE_integer('model_version', 1, 'Version number of the model.')
tf.app.flags.DEFINE_integer('num_v', 3952,
                            'Number of visible neurons (Number of movies the users rated.)')
FLAGS = tf.app.flags.FLAGS
def run_inference():
    inference_graph=tf.Graph()
    with inference_graph.as_default():
        model=InferenceModel(FLAGS)
        input_data=tf.placeholder(tf.float32, shape=[None, 3952])  
        ratings=model.inference(input_data)
        saver = tf.train.Saver()
    with tf.Session(graph=inference_graph) as sess:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoints_path)   
        saver.restore(sess, ckpt.model_checkpoint_path)      
        # Save the model
        export_path = os.path.join(tf.compat.as_bytes(FLAGS.export_path_base),
                                   tf.compat.as_bytes('model_v_%s'%str(FLAGS.model_version)))
        print('Exporting trained model to %s'%export_path)
        builder = tf.saved_model.builder.SavedModelBuilder(export_path)
        # create tensors info
        predict_tensor_inputs_info = tf.saved_model.utils.build_tensor_info(input_data)
        predict_tensor_scores_info = tf.saved_model.utils.build_tensor_info(ratings)
        # build prediction signature
        prediction_signature = (
            tf.saved_model.signature_def_utils.build_signature_def(
                inputs={'inputs': predict_tensor_inputs_info},
                outputs={'ratings': predict_tensor_scores_info},
                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
            )
        )
        # save the model
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                'predict_ratings': prediction_signature
            })
        builder.save()
if __name__ == "__main__":
    run_inference() 

这是一个 Python 脚本,它用于导出经过训练的模型,使其可以在生产环境中进行推理。该脚本首先使用 TensorFlow 的 flags 定义了一些参数,如模型版本号、模型路径、输出目录等等。然后,它创建了一个名为 inference_graph 的 TensorFlow 图,并定义了一个 InferenceModel,该模型用于从输入数据中推断评级。接下来,该脚本定义了一个 TensorFlow 会话,加载了已训练的模型,并通过调用 builder.add_meta_graph_and_variables 方法将模型导出到指定目录。在导出过程中,该脚本还使用了 build_signature_def 方法定义了一个 prediction_signature,以便在推理时使用。最后,该脚本调用 builder.save 方法将导出的模型保存到指定目录中。


这是一个 Python 脚本,主要实现了将已经训练好的推荐系统模型进行导出,以便后续可以部署到生产环境中使用。

代码首先导入了必要的 Python 模块和自定义模块,其中 tensorflow 是关键的模块,因为它是实现深度学习的核心框架。然后定义了四个 TensorFlow flag,分别用于设置 checkpoints 的路径、导出模型的基础路径、模型的版本号和可见神经元的数量(即用户评分的电影数量)。

接着定义了一个 run_inference() 函数,该函数是整个脚本的主要实现部分。在函数内部,首先创建了一个 TensorFlow 计算图 inference_graph,并使用 with 语句将该计算图设置为默认计算图。接着创建了一个 InferenceModel 对象 model,该对象定义了推荐系统模型的结构和计算图中的操作。然后定义了一个占位符 input_data,用于接收输入的测试数据。使用 model.inference() 方法对输入数据进行推断,得到预测的评分 ratings。接着创建了一个 TensorFlow 的 Saver 对象 saver,用于加载已经训练好的模型的权重。

然后创建了一个 TensorFlow 会话 sess,并将之前训练好的模型的 checkpoint 加载到会话中。接着将导出模型的路径设置为 FLAGS.export_path_base + 'model_v_' + FLAGS.model_version,并打印出导出路径。使用 tf.saved_model.builder.SavedModelBuilder() 创建了一个 SavedModelBuilder 对象 builder,用于保存导出的模型。接着使用 tf.saved_model.utils.build_tensor_info() 函数创建了输入数据 input_data 和预测评分 ratings 的 tensor 信息,并使用 tf.saved_model.signature_def_utils.build_signature_def() 函数创建了一个 prediction signature,表示模型的输入和输出。最后使用 builder.add_meta_graph_and_variables() 方法将计算图和变量添加到 SavedModelBuilder 对象中,并使用 builder.save() 方法将导出的模型保存到硬盘中。

最后,如果该脚本被直接执行,就会调用 run_inference() 函数,开始导出模型。

相关文章
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
深入 YOLOv8:探索 block.py 中的模块,逐行代码分析(三)
深入 YOLOv8:探索 block.py 中的模块,逐行代码分析(三)
|
6月前
|
机器学习/深度学习 编解码 PyTorch
深入 YOLOv8:探索 block.py 中的模块,逐行代码分析(四)
深入 YOLOv8:探索 block.py 中的模块,逐行代码分析(四)
|
6月前
|
机器学习/深度学习 编解码 计算机视觉
深入 YOLOv8:探索 block.py 中的模块,逐行代码分析(一)
深入 YOLOv8:探索 block.py 中的模块,逐行代码分析(一)
|
6月前
|
机器学习/深度学习 编解码 PyTorch
深入 YOLOv8:探索 block.py 中的模块,逐行代码分析(二)
深入 YOLOv8:探索 block.py 中的模块,逐行代码分析(二)
|
程序员 开发者 Python
#PY小贴士# py2 和 py3 的差别到底有多大?
虽然结论已经很明确,但我还是想客观地说一句:对于学习者来说,学 py2 还是 py3,真的没有太大差别。之所以这会成为一个问题
|
机器学习/深度学习 测试技术 TensorFlow
dataset.py代码解释
这段代码主要定义了三个函数来创建 TensorFlow 数据集对象,这些数据集对象将被用于训练、评估和推断神经网络模型。
127 0
|
6月前
|
索引
yolov5--detect.py --v5.0版本-最新代码详细解释-2021-6-29号更新
yolov5--detect.py --v5.0版本-最新代码详细解释-2021-6-29号更新
265 0
yolov5--detect.py --v5.0版本-最新代码详细解释-2021-6-29号更新
|
6月前
yolov5--datasets.py --v5.0版本-数据集加载 最新代码详细解释2021-7-5更新
yolov5--datasets.py --v5.0版本-数据集加载 最新代码详细解释2021-7-5更新
277 0
|
6月前
|
机器学习/深度学习 索引
yolov5--loss.py --v5.0版本-最新代码详细解释-2021-7-1更新
yolov5--loss.py --v5.0版本-最新代码详细解释-2021-7-1更新
297 0
|
数据可视化 PyTorch 计算机视觉
YOLOv5源码逐行超详细注释与解读(3)——训练部分train.py
YOLOv5源码逐行超详细注释与解读(3)——训练部分train.py
3255 4
YOLOv5源码逐行超详细注释与解读(3)——训练部分train.py