inference.py的代码解释

在线体验各类最新模型,更有模型 免费Token 额度领取!
立即体验
简介: 这是一个 Python 脚本,它用于导出经过训练的模型,使其可以在生产环境中进行推理。该脚本首先使用 TensorFlow 的 flags 定义了一些参数,如模型版本号、模型路径、输出目录等等。然后,它创建了一个名为 inference_graph 的 TensorFlow 图,并定义了一个 InferenceModel,该模型用于从输入数据中推断评级。
# -*- coding: utf-8 -*-
"""
Created on Sat Jan 19 13:39:52 2019
@author: Artem Oppermann
"""
import tensorflow as tf
import os
from model.inference_model import InferenceModel
tf.app.flags.DEFINE_string('checkpoints_path', os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..', 'checkpoints/')), 
                           'Path for the test data.')
tf.app.flags.DEFINE_string('export_path_base', os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..', 'model-export/')), 
                           'Directory where to export the model.')
tf.app.flags.DEFINE_integer('model_version', 1, 'Version number of the model.')
tf.app.flags.DEFINE_integer('num_v', 3952,
                            'Number of visible neurons (Number of movies the users rated.)')
FLAGS = tf.app.flags.FLAGS
def run_inference():
    inference_graph=tf.Graph()
    with inference_graph.as_default():
        model=InferenceModel(FLAGS)
        input_data=tf.placeholder(tf.float32, shape=[None, 3952])  
        ratings=model.inference(input_data)
        saver = tf.train.Saver()
    with tf.Session(graph=inference_graph) as sess:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoints_path)   
        saver.restore(sess, ckpt.model_checkpoint_path)      
        # Save the model
        export_path = os.path.join(tf.compat.as_bytes(FLAGS.export_path_base),
                                   tf.compat.as_bytes('model_v_%s'%str(FLAGS.model_version)))
        print('Exporting trained model to %s'%export_path)
        builder = tf.saved_model.builder.SavedModelBuilder(export_path)
        # create tensors info
        predict_tensor_inputs_info = tf.saved_model.utils.build_tensor_info(input_data)
        predict_tensor_scores_info = tf.saved_model.utils.build_tensor_info(ratings)
        # build prediction signature
        prediction_signature = (
            tf.saved_model.signature_def_utils.build_signature_def(
                inputs={'inputs': predict_tensor_inputs_info},
                outputs={'ratings': predict_tensor_scores_info},
                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
            )
        )
        # save the model
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                'predict_ratings': prediction_signature
            })
        builder.save()
if __name__ == "__main__":
    run_inference() 

这是一个 Python 脚本,它用于导出经过训练的模型,使其可以在生产环境中进行推理。该脚本首先使用 TensorFlow 的 flags 定义了一些参数,如模型版本号、模型路径、输出目录等等。然后,它创建了一个名为 inference_graph 的 TensorFlow 图,并定义了一个 InferenceModel,该模型用于从输入数据中推断评级。接下来,该脚本定义了一个 TensorFlow 会话,加载了已训练的模型,并通过调用 builder.add_meta_graph_and_variables 方法将模型导出到指定目录。在导出过程中,该脚本还使用了 build_signature_def 方法定义了一个 prediction_signature,以便在推理时使用。最后,该脚本调用 builder.save 方法将导出的模型保存到指定目录中。


这是一个 Python 脚本,主要实现了将已经训练好的推荐系统模型进行导出,以便后续可以部署到生产环境中使用。

代码首先导入了必要的 Python 模块和自定义模块,其中 tensorflow 是关键的模块,因为它是实现深度学习的核心框架。然后定义了四个 TensorFlow flag,分别用于设置 checkpoints 的路径、导出模型的基础路径、模型的版本号和可见神经元的数量(即用户评分的电影数量)。

接着定义了一个 run_inference() 函数,该函数是整个脚本的主要实现部分。在函数内部,首先创建了一个 TensorFlow 计算图 inference_graph,并使用 with 语句将该计算图设置为默认计算图。接着创建了一个 InferenceModel 对象 model,该对象定义了推荐系统模型的结构和计算图中的操作。然后定义了一个占位符 input_data,用于接收输入的测试数据。使用 model.inference() 方法对输入数据进行推断,得到预测的评分 ratings。接着创建了一个 TensorFlow 的 Saver 对象 saver,用于加载已经训练好的模型的权重。

然后创建了一个 TensorFlow 会话 sess,并将之前训练好的模型的 checkpoint 加载到会话中。接着将导出模型的路径设置为 FLAGS.export_path_base + 'model_v_' + FLAGS.model_version,并打印出导出路径。使用 tf.saved_model.builder.SavedModelBuilder() 创建了一个 SavedModelBuilder 对象 builder,用于保存导出的模型。接着使用 tf.saved_model.utils.build_tensor_info() 函数创建了输入数据 input_data 和预测评分 ratings 的 tensor 信息,并使用 tf.saved_model.signature_def_utils.build_signature_def() 函数创建了一个 prediction signature,表示模型的输入和输出。最后使用 builder.add_meta_graph_and_variables() 方法将计算图和变量添加到 SavedModelBuilder 对象中,并使用 builder.save() 方法将导出的模型保存到硬盘中。

最后,如果该脚本被直接执行,就会调用 run_inference() 函数,开始导出模型。

相关文章
|
3月前
|
安全 IDE Shell
【全网最详细】Python3.10下载安装使用保姆级教程(看不懂打我)
Python 3.10(2021年10月发布)是已结束支持的旧LTS版本,以稳定性著称。核心新特性为结构模式匹配,显著提升数据处理能力;同时优化错误提示、类型注解等。适合维护老项目或对稳定性要求高的生产环境。
|
Web App开发 测试技术 API
使用Ollama和Botnow本地部署DeepSeek R1模型的对比分析
本文详细对比了使用Ollama和Botnow两种方式在本地运行DeepSeek R1等开源大模型的不同。通过Ollama,用户可以在个人电脑(如MacBook Pro)上快速部署和测试模型;而Botnow则提供了企业级的API接入和本地部署方案,支持更复杂的应用场景。具体步骤包括环境准备、模型下载与运行、图形化界面操作等,帮助用户选择最适合自己的方式体验大模型的强大功能。
1054 0
|
机器学习/深度学习 TensorFlow 算法框架/工具
YOLOv11改进策略【卷积层】| 利用MobileNetv4中的UIB、ExtraDW优化C3k2
YOLOv11改进策略【卷积层】| 利用MobileNetv4中的UIB、ExtraDW优化C3k2
1068 0
YOLOv11改进策略【卷积层】| 利用MobileNetv4中的UIB、ExtraDW优化C3k2
|
机器学习/深度学习 人工智能 自然语言处理
三行代码实现实时语音转文本,支持自动断句和语音唤醒,用 RealtimeSTT 轻松创建高效语音 AI 助手
RealtimeSTT 是一款开源的实时语音转文本库,支持低延迟应用,具备语音活动检测、唤醒词激活等功能,适用于语音助手、实时字幕等场景。
3261 18
三行代码实现实时语音转文本,支持自动断句和语音唤醒,用 RealtimeSTT 轻松创建高效语音 AI 助手
|
人工智能 机器人 Linux
把大模型变成微信私人助手,三步搞定!
随着大模型的应用越来越广泛,相信大家都对拥有一个自己的私人AI助手越来越感兴趣。然而基于大模型遵循的"规模效应"(Scaling Law)原理,传统部署方式面临三重阻碍:高昂的运维成本、复杂的技术门槛(需掌握模型部署、量化等技术概念)以及系统集成难题。
1553 0
|
机器学习/深度学习 自然语言处理 算法
深度学习基础知识:介绍深度学习的发展历程、基本概念和主要应用
深度学习基础知识:介绍深度学习的发展历程、基本概念和主要应用
7541 0
|
人工智能 自然语言处理 搜索推荐
《深度剖析:开源与闭源模型,AI舞台上的不同角色》
在人工智能领域,开源与闭源模型各有优劣。闭源模型由大公司精心打造,初始性能优越,但优化受限;开源模型则依靠社区力量,灵活性高、迭代迅速,长期潜力大。在学术研究中,开源模型透明性高,利于创新;商业应用上,闭源模型稳定性强,适合高要求场景。资源受限环境中,开源模型更易裁剪优化。企业和开发者应根据需求选择合适模型,两者共同推动AI发展。
3038 9
|
人工智能 自然语言处理 并行计算
探索大模型部署:基于 VLLM 和 ModelScope 与 Qwen2.5 在双 32G VGPU 上的实践之旅
本文介绍了使用 `VLLM` 和 `ModelScope` 部署 `Qwen2.5` 大模型的实践过程,包括环境搭建、模型下载和在双 32G VGPU 上的成功部署,展现了高性能计算与大模型结合的强大力量。
4058 3
|
移动开发 小程序 API
uniapp组件库Card 卡片 的使用方法
uniapp组件库Card 卡片 的使用方法
1044 1
|
存储 编译器 C语言
C与C++之间相互调用的基本方法
C与C++之间相互调用的基本方法
406 1

热门文章

最新文章