机器学习PAI怎么把frozen graph保存成 tf serving呢?就是执行完optimize_for_inference得到frozen_graph_def ,怎么再保存成model = tf.saved_model.load可以直接加载的模型
在机器学习 PAI 中,将优化后的冻结图保存为 TensorFlow Serving 可加载的模型可以通过以下步骤完成:
tf.io.gfile.GFile
方法加载优化后的冻结图文件,并将其反序列化为 tf.GraphDef
对象。import tensorflow as tf
frozen_graph_file = 'optimized_graph.pb'
# 加载冻结图
with tf.io.gfile.GFile(frozen_graph_file, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
with tf.compat.v1.Session() as sess:
# 导入冻结图
tf.import_graph_def(graph_def, name='')
tf.saved_model.Builder
对象:tf.saved_model.Builder
对象,并指定要保存的模型版本号和路径。saved_model_path = 'saved_model/1' # 模型保存路径,版本号为 1
builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(saved_model_path)
inputs = {'input_node': sess.graph.get_tensor_by_name('input_node:0')}
outputs = {'output_node': sess.graph.get_tensor_by_name('output_node:0')}
signature_def = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
inputs, outputs,
method_name=tf.compat.v1.saved_model.signature_constants.PREDICT_METHOD_NAME
)
builder.add_meta_graph_and_variables(
sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
signature_def_map={'serving_default': signature_def}
)
builder.save()
方法将模型保存到指定路径。builder.save()
执行完上述步骤后,你将会在指定的路径下得到一个 TensorFlow Serving 可加载的模型。该模型可以使用 TensorFlow Serving 提供的 API 进行部署和预测。
请注意,具体的代码实现可能会因 TensorFlow 版本的不同而有所差异。上述示例代码基于 TensorFlow 2.x 进行编写。如果你使用的是 TensorFlow 1.x,请根据相应的 API 进行调整。
第一种,你可以直接使用得到的frozen graph推理;可以参考,可能需要改改:import tensorflow as tf
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile("/path/to/frozen/graph.pb", "rb") as f:
graph_def.ParseFromString(f.read())
with tf.compat.v1.Session() as sess:
tf.import_graph_def(graph_def)
# 获取输入和输出的Tensor
input_tensor = sess.graph.get_tensor_by_name("input:0")
output_tensor = sess.graph.get_tensor_by_name("output:0")
# 执行推理
output = sess.run(output_tensor, feed_dict={input_tensor: [[1.0, 2.0, 3.0, 4.0]]})
第二种;需要将frozen graph转换成saved model;可以参考可能需要修改:import tensorflow as tf
optimized_graph_def = tf.GraphDef()
with tf.io.gfile.GFile("/path/to/optimized/graph.pb", "rb") as f:
optimized_graph_def.ParseFromString(f.read())
builder = tf.compat.v1.saved_model.builder.SavedModelBuilder("/path/to/saved/model")
inputs = {
"input": input_tensor_info
}
outputs = {
"output": output_tensor_info
}
signature = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
inputs=inputs,
outputs=outputs,
method_name=tf.compat.v1.saved_model.signature_constants.PREDICT_METHOD_NAME)
with tf.compat.v1.Session() as sess:
# 导入GraphDef
tf.import_graph_def(optimized_graph_def)
# 添加图形和变量
builder.add_meta_graph_and_variables(
sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
signature_def_map={
tf.compat.v1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
signature
}
)
builder.save()7月11日 21:23,此回答整理自钉群“【EasyRec】推荐算法交流群”
版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。
人工智能平台 PAI(Platform for AI,原机器学习平台PAI)是面向开发者和企业的机器学习/深度学习工程平台,提供包含数据标注、模型构建、模型训练、模型部署、推理优化在内的AI开发全链路服务,内置140+种优化算法,具备丰富的行业场景插件,为用户提供低门槛、高性能的云原生AI工程化能力。