TensorFlow固化模型+打包程序+web API

简介: 固化Tensorflow模型,使用flask搭建简易web API,打包python代码

TensorFlow固化模型+打包程序+web API

训练过程保存模型

Tensorflow在训练过程中将参数和graph分开保存,例如使用下面的代码:

# -*- coding:utf-8 -*-
import tensorflow as tf
import os

dir = os.path.dirname(os.path.realpath(__file__))

v1 = tf.Variable(1, name='v1')
v2 = tf.placeholder(tf.int32, name='v2')

y = tf.add(v1, v2, name='add')

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    print(sess.run(y, feed_dict={v2: 2}))

    save_dir = dir+'/model'
    os.makedirs(save_dir, exist_ok=True)
    saver.save(sess, save_dir+'/model')

会生成4个文件,当然在训练的过程中除了checkpoint,其他三个文件会有多个。

checkpoint
model.data-00000-of-00001
model.index
model.meta

简单描述几个文件:
meta文件是保存图的(包括图,操作等)
data文件是保存数据的(权重)
index文件是一个不可修改的键值表

固化训练好的模型

在训练完成后选择效果最好的模型,进行压缩,或者将graph和权重放在一起以便生产使用。

# -*- coding:utf-8 -*-
import tensorflow as tf
import os

dir = os.path.dirname(os.path.realpath(__file__))
checkpoint = tf.train.get_checkpoint_state(dir + '/model')
input_checkpoint = checkpoint.model_checkpoint_path
print(input_checkpoint)

absolute_model = '/'.join(input_checkpoint.split('/')[:-1])
print(absolute_model)

output_grap = absolute_model + "/frozen_model.pb"
with tf.Session(graph=tf.Graph()) as sess:
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)

    saver.restore(sess, input_checkpoint)
    # 打印图中的变量,查看要保存的
    for op in tf.get_default_graph().get_operations():
        print(op.name, op.values())

    output_grap_def = tf.graph_util.convert_variables_to_constants(sess,
                                                                   tf.get_default_graph().as_graph_def(),
                                                                   output_node_names=['add'])
    with tf.gfile.GFile(output_grap, 'wb') as f:
        f.write(output_grap_def.SerializeToString())
    print("%d ops in the final graph." % len(output_grap_def.node))

此时model文件夹下就会多出frozen_model.pb文件

convert_variables_to_constants()函数的作用:

  1. 会将变量替换成常量固化起来
  2. 将前向传播不需要的节点node去掉
    所以output_node_names参数只要输入你的网络的输出,就会生成一个最小的序列化的二进制pb文件。

使用pb(protobuf)模型

# -*- coding:utf-8 -*-
import tensorflow as tf
import argparse
def load_graph(frozen_graph_file):
    with tf.gfile.GFile(frozen_graph_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='prefix')
    return graph


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--frozen_model_filename", default='frozen_model.pb',
                        type=str, help='Frozen model file to import')
    args = parser.parse_args()

    graph = load_graph(args.frozen_model_filename)

    v2 = graph.get_tensor_by_name('prefix/v2:0')
    add = graph.get_tensor_by_name('prefix/add:0')

    for op in graph.get_operations():
        print(op.name)

    with tf.Session(graph=graph) as sess:

        out = sess.run(add, feed_dict={v2: 10})
        print(out)

打包程序

上面的模型已经打包了,下面对test.py代码进行打包,与上面的不同的地方是将加法的第二个参数预留出来

# -*- coding:utf-8 -*-
import os
os.environ["PBR_VERSION"]='3.1.1'
import argparse
import tensorflow as tf



def load_graph(frozen_graph_file):
    with tf.gfile.GFile(frozen_graph_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='prefix')
        return graph


if __name__ == "__main__":
    # 创建一个解析对象
    parser = argparse.ArgumentParser()
    # 向parser对象中添加命令行参数和选项参数
    parser.add_argument('--num', type=int, help='add') # 留出加法的第二个数子作为参数
    parser.add_argument("--frozen_model_filename",
                        default='model/frozen_model.pb',
                        type=str, help='Frozen model file to import')
    # 进行解析
    args = parser.parse_args()

    graph = load_graph(args.frozen_model_filename)
    v2 = graph.get_tensor_by_name('prefix/v2:0')
    add = graph.get_tensor_by_name('prefix/add:0')

    with tf.Session(graph=graph) as sess:
        out = sess.run(add, feed_dict={v2: args.num})
        print(out)

使用 python test.py --num=10
输出 11

# 安装pyinstaller
# pip install pyinstaller
# -F 是 --onefile的缩写
# --clean 是清理临时文件,也就是build文件夹下的临时文件
pyinstaller -F  --clean test.py

完成后到dist文件夹下

./test --num=10

输出11
打包遇到的问题:

如果没有os.environ["PBR_VERSION"]='3.1.1'会报错

Traceback (most recent call last):
  File "pack_tf_add.py", line 4, in <module>
  File "/private/var/folders/88/1jw_0lt50tsb4n08mg_493040000gn/T/pip-build-3m08rf/pyinstaller/PyInstaller/loader/pyimod03_importers.py", line 396, in load_module
  File "site-packages/tensorflow/__init__.py", line 24, in <module>
  File "/private/var/folders/88/1jw_0lt50tsb4n08mg_493040000gn/T/pip-build-3m08rf/pyinstaller/PyInstaller/loader/pyimod03_importers.py", line 396, in load_module
  File "site-packages/tensorflow/python/__init__.py", line 104, in <module>
  File "/private/var/folders/88/1jw_0lt50tsb4n08mg_493040000gn/T/pip-build-3m08rf/pyinstaller/PyInstaller/loader/pyimod03_importers.py", line 396, in load_module
  File "site-packages/tensorflow/python/platform/test.py", line 53, in <module>
  File "/private/var/folders/88/1jw_0lt50tsb4n08mg_493040000gn/T/pip-build-3m08rf/pyinstaller/PyInstaller/loader/pyimod03_importers.py", line 396, in load_module
  File "site-packages/mock/__init__.py", line 2, in <module>
  File "/private/var/folders/88/1jw_0lt50tsb4n08mg_493040000gn/T/pip-build-3m08rf/pyinstaller/PyInstaller/loader/pyimod03_importers.py", line 396, in load_module
  File "site-packages/mock/mock.py", line 71, in <module>
  File "site-packages/pbr/version.py", line 462, in semantic_version
  File "site-packages/pbr/version.py", line 449, in _get_version_from_pkg_resources
  File "site-packages/pbr/packaging.py", line 812, in get_version
Exception: Versioning for this project requires either an sdist tarball, or access to an upstream git repository. It's also possible that there is a mismatch between the package name in setup.cfg and the argument given to pbr.version.VersionInfo. Project name mock was given, but was not able to be found.

解决方法:

https://blog.csdn.net/laocaibcc229/article/details/78570017
https://github.com/pyinstaller/pyinstaller/issues/2883

# 添加到首行
import os
os.environ["PBR_VERSION"]='3.1.1' #要去查询自己的版本

查看pbr版本

pbr --version # 3.1.1

web API

使用flask搭建一个微型web API

# -*- coding:utf-8 -*-
import argparse
from flask import Flask
from flask import request
import tensorflow as tf

app = Flask(__name__)

def load_graph(frozen_graph_file):
    with tf.gfile.GFile(frozen_graph_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='prefix')
        return graph
    
    
@app.route('/', methods=['POST', 'GET'])
def about():
    if request.method == "POST":
        print("in post")
        num = request.form.get('num')
        y_out = persistent_sess.run(y, feed_dict={x: num})

        return str(y_out)
    else:
        return """<form action="/" method="POST">
                  <input type="text" name="num" placeholder="Enter num">
                  <input type="submit" value="Submit" name="ok"/>
                  </form>"""


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--frozen_model_filename", default="frozen_model.pb", type=str,
                        help="Frozen model file to import")
    parser.add_argument("--gpu_memory", default=.2, type=float, help="GPU memory per process")
    args = parser.parse_args()

    print('Loading the model')
    graph = load_graph(args.frozen_model_filename)
    x = graph.get_tensor_by_name('prefix/v2:0')
    y = graph.get_tensor_by_name('prefix/add:0')
    # use gpu
    # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_memory)
    # sess_config = tf.ConfigProto(gpu_options=gpu_options)
    # persistent_sess = tf.Session(graph=graph, config=sess_config)

    # use cpu
    persistent_sess = tf.Session(graph=graph)
    print('Starting the API')
    app.run()

点击 http://127.0.0.1:5000/ 输入数字

点击 submit 显示结果

相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
2月前
|
Java API 数据库
构建RESTful API已经成为现代Web开发的标准做法之一。Spring Boot框架因其简洁的配置、快速的启动特性及丰富的功能集而备受开发者青睐。
【10月更文挑战第11天】本文介绍如何使用Spring Boot构建在线图书管理系统的RESTful API。通过创建Spring Boot项目,定义`Book`实体类、`BookRepository`接口和`BookService`服务类,最后实现`BookController`控制器来处理HTTP请求,展示了从基础环境搭建到API测试的完整过程。
60 4
|
3天前
|
存储 人工智能 API
AgentScope:阿里开源多智能体低代码开发平台,支持一键导出源码、多种模型API和本地模型部署
AgentScope是阿里巴巴集团开源的多智能体开发平台,旨在帮助开发者轻松构建和部署多智能体应用。该平台提供分布式支持,内置多种模型API和本地模型部署选项,支持多模态数据处理。
56 4
AgentScope:阿里开源多智能体低代码开发平台,支持一键导出源码、多种模型API和本地模型部署
|
15天前
|
Kubernetes 安全 Devops
有效抵御网络应用及API威胁,聊聊F5 BIG-IP Next Web应用防火墙
有效抵御网络应用及API威胁,聊聊F5 BIG-IP Next Web应用防火墙
39 10
有效抵御网络应用及API威胁,聊聊F5 BIG-IP Next Web应用防火墙
|
26天前
|
监控 前端开发 JavaScript
使用 MERN 堆栈构建可扩展 Web 应用程序的最佳实践
使用 MERN 堆栈构建可扩展 Web 应用程序的最佳实践
30 6
|
1月前
|
人工智能 Java API
ChatClient:探索与AI模型通信的Fluent API
【11月更文挑战第22天】随着人工智能(AI)技术的飞速发展,越来越多的应用场景开始融入AI技术以提升用户体验和系统效率。在Java开发中,与AI模型通信成为了一个重要而常见的需求。为了满足这一需求,Spring AI引入了ChatClient,一个提供流畅API(Fluent API)的客户端,用于与各种AI模型进行通信。本文将深入探讨ChatClient的底层原理、业务场景、概念、功能点,并通过Java代码示例展示如何使用Fluent API与AI模型进行通信。
50 8
|
1月前
|
开发框架 搜索推荐 数据可视化
Django框架适合开发哪种类型的Web应用程序?
Django 框架凭借其强大的功能、稳定性和可扩展性,几乎可以适应各种类型的 Web 应用程序开发需求。无论是简单的网站还是复杂的企业级系统,Django 都能提供可靠的支持,帮助开发者快速构建高质量的应用。同时,其活跃的社区和丰富的资源也为开发者在项目实施过程中提供了有力的保障。
|
1月前
|
前端开发 API 开发者
Python Web开发者必看!AJAX、Fetch API实战技巧,让前后端交互如丝般顺滑!
在Web开发中,前后端的高效交互是提升用户体验的关键。本文通过一个基于Flask框架的博客系统实战案例,详细介绍了如何使用AJAX和Fetch API实现不刷新页面查看评论的功能。从后端路由设置到前端请求处理,全面展示了这两种技术的应用技巧,帮助Python Web开发者提升项目质量和开发效率。
53 1
|
1月前
|
JSON API 数据格式
如何使用Python和Flask构建一个简单的RESTful API。Flask是一个轻量级的Web框架
本文介绍了如何使用Python和Flask构建一个简单的RESTful API。Flask是一个轻量级的Web框架,适合小型项目和微服务。文章从环境准备、创建基本Flask应用、定义资源和路由、请求和响应处理、错误处理等方面进行了详细说明,并提供了示例代码。通过这些步骤,读者可以快速上手构建自己的RESTful API。
102 2
|
1月前
|
数据可视化 数据库 开发者
使用Dash构建交互式Web应用程序
【10月更文挑战第16天】本文介绍了使用Python的Dash框架构建交互式Web应用程序的方法。Dash结合了Flask、React和Plotly等技术,让开发者能够快速创建功能丰富的数据可视化应用。文章从安装Dash开始,逐步介绍了创建简单应用程序、添加交互元素、部署应用程序以及集成更多功能的步骤,并提供了代码示例。通过本文,读者可以掌握使用Dash构建交互式Web应用程序的基本技巧和高级功能。
|
2月前
|
监控 负载均衡 API
Web、RESTful API 在微服务中有哪些作用?
在微服务架构中,Web 和 RESTful API 扮演着至关重要的角色。它们帮助实现服务之间的通信、数据交换和系统的可扩展性。
60 2

热门文章

最新文章