TensorFlow固化模型+打包程序+web API

简介: 固化Tensorflow模型,使用flask搭建简易web API,打包python代码

TensorFlow固化模型+打包程序+web API

训练过程保存模型

Tensorflow在训练过程中将参数和graph分开保存,例如使用下面的代码:

# -*- coding:utf-8 -*-
import tensorflow as tf
import os

dir = os.path.dirname(os.path.realpath(__file__))

v1 = tf.Variable(1, name='v1')
v2 = tf.placeholder(tf.int32, name='v2')

y = tf.add(v1, v2, name='add')

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    print(sess.run(y, feed_dict={v2: 2}))

    save_dir = dir+'/model'
    os.makedirs(save_dir, exist_ok=True)
    saver.save(sess, save_dir+'/model')

会生成4个文件,当然在训练的过程中除了checkpoint,其他三个文件会有多个。

checkpoint
model.data-00000-of-00001
model.index
model.meta

简单描述几个文件:
meta文件是保存图的(包括图,操作等)
data文件是保存数据的(权重)
index文件是一个不可修改的键值表

固化训练好的模型

在训练完成后选择效果最好的模型,进行压缩,或者将graph和权重放在一起以便生产使用。

# -*- coding:utf-8 -*-
import tensorflow as tf
import os

dir = os.path.dirname(os.path.realpath(__file__))
checkpoint = tf.train.get_checkpoint_state(dir + '/model')
input_checkpoint = checkpoint.model_checkpoint_path
print(input_checkpoint)

absolute_model = '/'.join(input_checkpoint.split('/')[:-1])
print(absolute_model)

output_grap = absolute_model + "/frozen_model.pb"
with tf.Session(graph=tf.Graph()) as sess:
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)

    saver.restore(sess, input_checkpoint)
    # 打印图中的变量,查看要保存的
    for op in tf.get_default_graph().get_operations():
        print(op.name, op.values())

    output_grap_def = tf.graph_util.convert_variables_to_constants(sess,
                                                                   tf.get_default_graph().as_graph_def(),
                                                                   output_node_names=['add'])
    with tf.gfile.GFile(output_grap, 'wb') as f:
        f.write(output_grap_def.SerializeToString())
    print("%d ops in the final graph." % len(output_grap_def.node))

此时model文件夹下就会多出frozen_model.pb文件

convert_variables_to_constants()函数的作用:

  1. 会将变量替换成常量固化起来
  2. 将前向传播不需要的节点node去掉
    所以output_node_names参数只要输入你的网络的输出,就会生成一个最小的序列化的二进制pb文件。

使用pb(protobuf)模型

# -*- coding:utf-8 -*-
import tensorflow as tf
import argparse
def load_graph(frozen_graph_file):
    with tf.gfile.GFile(frozen_graph_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='prefix')
    return graph


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--frozen_model_filename", default='frozen_model.pb',
                        type=str, help='Frozen model file to import')
    args = parser.parse_args()

    graph = load_graph(args.frozen_model_filename)

    v2 = graph.get_tensor_by_name('prefix/v2:0')
    add = graph.get_tensor_by_name('prefix/add:0')

    for op in graph.get_operations():
        print(op.name)

    with tf.Session(graph=graph) as sess:

        out = sess.run(add, feed_dict={v2: 10})
        print(out)

打包程序

上面的模型已经打包了,下面对test.py代码进行打包,与上面的不同的地方是将加法的第二个参数预留出来

# -*- coding:utf-8 -*-
import os
os.environ["PBR_VERSION"]='3.1.1'
import argparse
import tensorflow as tf



def load_graph(frozen_graph_file):
    with tf.gfile.GFile(frozen_graph_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='prefix')
        return graph


if __name__ == "__main__":
    # 创建一个解析对象
    parser = argparse.ArgumentParser()
    # 向parser对象中添加命令行参数和选项参数
    parser.add_argument('--num', type=int, help='add') # 留出加法的第二个数子作为参数
    parser.add_argument("--frozen_model_filename",
                        default='model/frozen_model.pb',
                        type=str, help='Frozen model file to import')
    # 进行解析
    args = parser.parse_args()

    graph = load_graph(args.frozen_model_filename)
    v2 = graph.get_tensor_by_name('prefix/v2:0')
    add = graph.get_tensor_by_name('prefix/add:0')

    with tf.Session(graph=graph) as sess:
        out = sess.run(add, feed_dict={v2: args.num})
        print(out)

使用 python test.py --num=10
输出 11

# 安装pyinstaller
# pip install pyinstaller
# -F 是 --onefile的缩写
# --clean 是清理临时文件,也就是build文件夹下的临时文件
pyinstaller -F  --clean test.py

完成后到dist文件夹下

./test --num=10

输出11
打包遇到的问题:

如果没有os.environ["PBR_VERSION"]='3.1.1'会报错

Traceback (most recent call last):
  File "pack_tf_add.py", line 4, in <module>
  File "/private/var/folders/88/1jw_0lt50tsb4n08mg_493040000gn/T/pip-build-3m08rf/pyinstaller/PyInstaller/loader/pyimod03_importers.py", line 396, in load_module
  File "site-packages/tensorflow/__init__.py", line 24, in <module>
  File "/private/var/folders/88/1jw_0lt50tsb4n08mg_493040000gn/T/pip-build-3m08rf/pyinstaller/PyInstaller/loader/pyimod03_importers.py", line 396, in load_module
  File "site-packages/tensorflow/python/__init__.py", line 104, in <module>
  File "/private/var/folders/88/1jw_0lt50tsb4n08mg_493040000gn/T/pip-build-3m08rf/pyinstaller/PyInstaller/loader/pyimod03_importers.py", line 396, in load_module
  File "site-packages/tensorflow/python/platform/test.py", line 53, in <module>
  File "/private/var/folders/88/1jw_0lt50tsb4n08mg_493040000gn/T/pip-build-3m08rf/pyinstaller/PyInstaller/loader/pyimod03_importers.py", line 396, in load_module
  File "site-packages/mock/__init__.py", line 2, in <module>
  File "/private/var/folders/88/1jw_0lt50tsb4n08mg_493040000gn/T/pip-build-3m08rf/pyinstaller/PyInstaller/loader/pyimod03_importers.py", line 396, in load_module
  File "site-packages/mock/mock.py", line 71, in <module>
  File "site-packages/pbr/version.py", line 462, in semantic_version
  File "site-packages/pbr/version.py", line 449, in _get_version_from_pkg_resources
  File "site-packages/pbr/packaging.py", line 812, in get_version
Exception: Versioning for this project requires either an sdist tarball, or access to an upstream git repository. It's also possible that there is a mismatch between the package name in setup.cfg and the argument given to pbr.version.VersionInfo. Project name mock was given, but was not able to be found.

解决方法:

https://blog.csdn.net/laocaibcc229/article/details/78570017
https://github.com/pyinstaller/pyinstaller/issues/2883

# 添加到首行
import os
os.environ["PBR_VERSION"]='3.1.1' #要去查询自己的版本

查看pbr版本

pbr --version # 3.1.1

web API

使用flask搭建一个微型web API

# -*- coding:utf-8 -*-
import argparse
from flask import Flask
from flask import request
import tensorflow as tf

app = Flask(__name__)

def load_graph(frozen_graph_file):
    with tf.gfile.GFile(frozen_graph_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='prefix')
        return graph
    
    
@app.route('/', methods=['POST', 'GET'])
def about():
    if request.method == "POST":
        print("in post")
        num = request.form.get('num')
        y_out = persistent_sess.run(y, feed_dict={x: num})

        return str(y_out)
    else:
        return """<form action="/" method="POST">
                  <input type="text" name="num" placeholder="Enter num">
                  <input type="submit" value="Submit" name="ok"/>
                  </form>"""


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--frozen_model_filename", default="frozen_model.pb", type=str,
                        help="Frozen model file to import")
    parser.add_argument("--gpu_memory", default=.2, type=float, help="GPU memory per process")
    args = parser.parse_args()

    print('Loading the model')
    graph = load_graph(args.frozen_model_filename)
    x = graph.get_tensor_by_name('prefix/v2:0')
    y = graph.get_tensor_by_name('prefix/add:0')
    # use gpu
    # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_memory)
    # sess_config = tf.ConfigProto(gpu_options=gpu_options)
    # persistent_sess = tf.Session(graph=graph, config=sess_config)

    # use cpu
    persistent_sess = tf.Session(graph=graph)
    print('Starting the API')
    app.run()

点击 http://127.0.0.1:5000/ 输入数字

点击 submit 显示结果

相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
1月前
|
人工智能 前端开发 API
Gemini Coder:基于 Google Gemini API 的开源 Web 应用生成工具,支持实时编辑和预览
Gemini Coder 是一款基于 Google Gemini API 的 AI 应用生成工具,支持通过文本描述快速生成代码,并提供实时代码编辑和预览功能,简化开发流程。
139 38
Gemini Coder:基于 Google Gemini API 的开源 Web 应用生成工具,支持实时编辑和预览
|
4月前
|
Java API 数据库
构建RESTful API已经成为现代Web开发的标准做法之一。Spring Boot框架因其简洁的配置、快速的启动特性及丰富的功能集而备受开发者青睐。
【10月更文挑战第11天】本文介绍如何使用Spring Boot构建在线图书管理系统的RESTful API。通过创建Spring Boot项目,定义`Book`实体类、`BookRepository`接口和`BookService`服务类,最后实现`BookController`控制器来处理HTTP请求,展示了从基础环境搭建到API测试的完整过程。
77 4
|
3月前
|
开发框架 搜索推荐 数据可视化
Django框架适合开发哪种类型的Web应用程序?
Django 框架凭借其强大的功能、稳定性和可扩展性,几乎可以适应各种类型的 Web 应用程序开发需求。无论是简单的网站还是复杂的企业级系统,Django 都能提供可靠的支持,帮助开发者快速构建高质量的应用。同时,其活跃的社区和丰富的资源也为开发者在项目实施过程中提供了有力的保障。
156 62
|
2月前
|
Kubernetes 安全 Devops
有效抵御网络应用及API威胁,聊聊F5 BIG-IP Next Web应用防火墙
有效抵御网络应用及API威胁,聊聊F5 BIG-IP Next Web应用防火墙
97 10
有效抵御网络应用及API威胁,聊聊F5 BIG-IP Next Web应用防火墙
|
3月前
|
监控 前端开发 JavaScript
使用 MERN 堆栈构建可扩展 Web 应用程序的最佳实践
使用 MERN 堆栈构建可扩展 Web 应用程序的最佳实践
57 6
|
3月前
|
机器学习/深度学习 数据采集 数据可视化
TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤
本文介绍了 TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤,包括数据准备、模型定义、损失函数与优化器选择、模型训练与评估、模型保存与部署,并展示了构建全连接神经网络的具体示例。此外,还探讨了 TensorFlow 的高级特性,如自动微分、模型可视化和分布式训练,以及其在未来的发展前景。
335 5
|
3月前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
179 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
3月前
|
机器学习/深度学习 人工智能 算法
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
蔬菜识别系统,本系统使用Python作为主要编程语言,通过收集了8种常见的蔬菜图像数据集('土豆', '大白菜', '大葱', '莲藕', '菠菜', '西红柿', '韭菜', '黄瓜'),然后基于TensorFlow搭建卷积神经网络算法模型,通过多轮迭代训练最后得到一个识别精度较高的模型文件。在使用Django开发web网页端操作界面,实现用户上传一张蔬菜图片识别其名称。
171 0
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
|
3月前
|
前端开发 API 开发者
Python Web开发者必看!AJAX、Fetch API实战技巧,让前后端交互如丝般顺滑!
在Web开发中,前后端的高效交互是提升用户体验的关键。本文通过一个基于Flask框架的博客系统实战案例,详细介绍了如何使用AJAX和Fetch API实现不刷新页面查看评论的功能。从后端路由设置到前端请求处理,全面展示了这两种技术的应用技巧,帮助Python Web开发者提升项目质量和开发效率。
87 1
|
3月前
|
JSON API 数据格式
如何使用Python和Flask构建一个简单的RESTful API。Flask是一个轻量级的Web框架
本文介绍了如何使用Python和Flask构建一个简单的RESTful API。Flask是一个轻量级的Web框架,适合小型项目和微服务。文章从环境准备、创建基本Flask应用、定义资源和路由、请求和响应处理、错误处理等方面进行了详细说明,并提供了示例代码。通过这些步骤,读者可以快速上手构建自己的RESTful API。
248 2

热门文章

最新文章