# -*- coding: utf-8 -*- """ Created on Sat Jan 19 13:39:52 2019 @author: Artem Oppermann """ import tensorflow as tf import os from model.inference_model import InferenceModel tf.app.flags.DEFINE_string('checkpoints_path', os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..', 'checkpoints/')), 'Path for the test data.') tf.app.flags.DEFINE_string('export_path_base', os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..', 'model-export/')), 'Directory where to export the model.') tf.app.flags.DEFINE_integer('model_version', 1, 'Version number of the model.') tf.app.flags.DEFINE_integer('num_v', 3952, 'Number of visible neurons (Number of movies the users rated.)') FLAGS = tf.app.flags.FLAGS def run_inference(): inference_graph=tf.Graph() with inference_graph.as_default(): model=InferenceModel(FLAGS) input_data=tf.placeholder(tf.float32, shape=[None, 3952]) ratings=model.inference(input_data) saver = tf.train.Saver() with tf.Session(graph=inference_graph) as sess: ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoints_path) saver.restore(sess, ckpt.model_checkpoint_path) # Save the model export_path = os.path.join(tf.compat.as_bytes(FLAGS.export_path_base), tf.compat.as_bytes('model_v_%s'%str(FLAGS.model_version))) print('Exporting trained model to %s'%export_path) builder = tf.saved_model.builder.SavedModelBuilder(export_path) # create tensors info predict_tensor_inputs_info = tf.saved_model.utils.build_tensor_info(input_data) predict_tensor_scores_info = tf.saved_model.utils.build_tensor_info(ratings) # build prediction signature prediction_signature = ( tf.saved_model.signature_def_utils.build_signature_def( inputs={'inputs': predict_tensor_inputs_info}, outputs={'ratings': predict_tensor_scores_info}, method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME ) ) # save the model builder.add_meta_graph_and_variables( sess, [tf.saved_model.tag_constants.SERVING], signature_def_map={ 'predict_ratings': prediction_signature }) builder.save() if __name__ == "__main__": run_inference()
这是一个 Python 脚本,它用于导出经过训练的模型,使其可以在生产环境中进行推理。该脚本首先使用 TensorFlow 的 flags 定义了一些参数,如模型版本号、模型路径、输出目录等等。然后,它创建了一个名为 inference_graph 的 TensorFlow 图,并定义了一个 InferenceModel,该模型用于从输入数据中推断评级。接下来,该脚本定义了一个 TensorFlow 会话,加载了已训练的模型,并通过调用 builder.add_meta_graph_and_variables 方法将模型导出到指定目录。在导出过程中,该脚本还使用了 build_signature_def 方法定义了一个 prediction_signature,以便在推理时使用。最后,该脚本调用 builder.save 方法将导出的模型保存到指定目录中。
这是一个 Python 脚本,主要实现了将已经训练好的推荐系统模型进行导出,以便后续可以部署到生产环境中使用。
代码首先导入了必要的 Python 模块和自定义模块,其中 tensorflow
是关键的模块,因为它是实现深度学习的核心框架。然后定义了四个 TensorFlow flag,分别用于设置 checkpoints 的路径、导出模型的基础路径、模型的版本号和可见神经元的数量(即用户评分的电影数量)。
接着定义了一个 run_inference()
函数,该函数是整个脚本的主要实现部分。在函数内部,首先创建了一个 TensorFlow 计算图 inference_graph
,并使用 with
语句将该计算图设置为默认计算图。接着创建了一个 InferenceModel
对象 model
,该对象定义了推荐系统模型的结构和计算图中的操作。然后定义了一个占位符 input_data
,用于接收输入的测试数据。使用 model.inference()
方法对输入数据进行推断,得到预测的评分 ratings
。接着创建了一个 TensorFlow 的 Saver 对象 saver
,用于加载已经训练好的模型的权重。
然后创建了一个 TensorFlow 会话 sess
,并将之前训练好的模型的 checkpoint 加载到会话中。接着将导出模型的路径设置为 FLAGS.export_path_base
+ 'model_v_' + FLAGS.model_version
,并打印出导出路径。使用 tf.saved_model.builder.SavedModelBuilder()
创建了一个 SavedModelBuilder 对象 builder
,用于保存导出的模型。接着使用 tf.saved_model.utils.build_tensor_info()
函数创建了输入数据 input_data
和预测评分 ratings
的 tensor 信息,并使用 tf.saved_model.signature_def_utils.build_signature_def()
函数创建了一个 prediction signature,表示模型的输入和输出。最后使用 builder.add_meta_graph_and_variables()
方法将计算图和变量添加到 SavedModelBuilder 对象中,并使用 builder.save()
方法将导出的模型保存到硬盘中。
最后,如果该脚本被直接执行,就会调用 run_inference()
函数,开始导出模型。