使用Java部署TensorFlow和Keras训练好的深度学习模型的几种方法

简介: 使用Java部署TensorFlow和Keras训练好的深度学习模型的几种方法写在前面最近在一个自然语言处理方面的项目,选用的深度学习模型有两个,一个是CNN+LSTM模型,一个是GRU模型,这两个模型在GPU服务器上训练好了,然后需要使用Java调用这两个模型,CNN+LSTM使用TensorFlow写的,GRU是用Keras写的,所以需要用Java部署TensorFlow和Keras训练好的深度学习模型。

使用Java部署TensorFlow和Keras训练好的深度学习模型的几种方法
写在前面
最近在一个自然语言处理方面的项目,选用的深度学习模型有两个,一个是CNN+LSTM模型,一个是GRU模型,这两个模型在GPU服务器上训练好了,然后需要使用Java调用这两个模型,CNN+LSTM使用TensorFlow写的,GRU是用Keras写的,所以需要用Java部署TensorFlow和Keras训练好的深度学习模型。关于这方面的内容网上并不是很多,我也是费了很多周折才完成任务的,这里来总结一下具体有哪些方法可以用,这些方法又有哪些缺陷,以供大家学习交流。

一、使用Java深度学习框架直接部署
(1)使用TensorFlow Java API部署TensorFlow模型
如果我们使用的是TensorFlow训练的模型,那么我们就可以直接使用Java中的TensorFlow API调用模型。这里需要注意的是我们得把训练好的模型保存为.pb格式的文件。具体代码如下:

constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=["quest_out"])

写入序列化的 PB 文件

with tf.gfile.FastGFile('/home/amax/zth/qa/new_model2_cpu.pb', mode='wb') as f:

f.write(constant_graph.SerializeToString())

然后我们需要在Java使用这个保存好的模型,在pom.xml中引入TensorFlow的依赖

        <groupId>org.tensorflow</groupId>
        <artifactId>tensorflow</artifactId>
        <version>1.11.0</version>
    </dependency>

    <!-- https://mvnrepository.com/artifact/org.tensorflow/libtensorflow_jni_gpu -->
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>libtensorflow_jni_gpu</artifactId>
        <version>1.11.0</version>
    </dependency>

导包成功后,在Java中调用模型

graphDef = readAllBytes(new FileInputStream(new_model2_cpu.pb));
graph = new Graph();
graph.importGraphDef(graphDef);
session = new Session(graph);
Tensor result = session.runner()

            .feed("ori_quest_embedding", Tensor.create(wordVecInputSentence))//输入你自己的数据
            .feed("dropout", Tensor.create(1.0F))
            .fetch("quest_out") //和上面python保存模型时的output_node_names对应
            .run().get(0);

//这样就能得到模型的输出结果了

(2)使用Deeplearning4J Java API部署Keras模型
如果我们使用的是Keras训练的模型,那么你就可以选择Deeplearning4J 这个框架来调用模型。
第一步同样是使用Keras保存训练好的模型

filepath = "query_models"

checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True)

callback_lists = [checkpoint]

model.fit(x, y, epochs=1,validation_split=0.2,callbacks=callback_lists)

然后同样是Java项目中pom.xml导入Deeplearning4J 依赖

<groupId>org.deeplearning4j</groupId>      
<artifactId>deeplearning4j-core</artifactId>
<version>1.0.0-beta2</version>    


<groupId>org.deeplearning4j</groupId>      
<artifactId>deeplearning4j-modelimport</artifactId>      
<version>1.0.0-beta2</version>    

库导入成功后,直接使用Java调用保存好的模型

MultiLayerNetwork model = KerasModelImport.importKerasSequentialModelAndWeights(“query_models”);

这样模型就部署成功了,然后关于怎么使用模型这里就不多说了。
注意:这里需要注意的是Deeplearning4J 只支持部分深度学习模型,有些模型是不支持的,譬如我这里使用的GRU模型就不支持,运行上面代码会出现以下错误。明确指明不支持GRU模型

去Deeplearning4J 官网查询发现确实现在不支持GRU模型,以下是官网截图

所以如果你想使用Deeplearning4J 来部署训练好的模型,请先查看下是否支持你所使用的模型。

二、使用Python编写服务端
(1)使用socket实现进程间的通信
用python构建服务端,然后通过Java向服务端发送请求调用模型,第一种是使用socket实现进程中的通信,代码如下:

import socket
import sys
import threading
import json
import numpy as np
import jieba
import os
import numpy as np
import nltk
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM,GRU,TimeDistributed
from keras.callbacks import ModelCheckpoint
from keras.utils import np_utils
from gensim.models.word2vec import Word2Vec
from keras.optimizers import Adam
from keras.models import load_model
import pickle

nn=network.getNetWork()

cnn = conv.main(False)

深度学习训练的神经网络,使用TensorFlow训练的神经网络模型,保存在文件中

w2v_model = Word2Vec.load("word2vec.w2v").wv
UNK = pickle.load(open('unk.pkl','rb'))
model = load_model('query_models')
a = np.zeros((1, 223,200))
model.predict(a)

def test_init(string):

cut_list = jieba.lcut(string)
x_test = np.zeros((1, 223,200))
for i in range(223):
    x_test[0,i,:] = UNK
for i in range(len(cut_list)):
    if cut_list[i] in w2v_model:
        x_test[0,i,:] = w2v_model.wv[cut_list[i]]   
return x_test,len(cut_list)

string_list = list()
def query_complet(string):

x_test,length = test_init(string)
y = model.predict(x_test)
if length>8:
    return
word1 = w2v_model.wv.most_similar(positive=[y[0][length-1]], topn=2)[0][0]
word2 = w2v_model.wv.most_similar(positive=[y[0][length-1]], topn=2)[1][0]
if word1 == '?' or word1 == '?':
    string_list.append(string)
else:
    new_str = string+word1
    query_complet(new_str)

if word2 == '?' or word2 == '?':
    string_list.append(string)
else:
    new_str = string+word2
    query_complet(new_str)
    

def new_query_complet(string):

query_complet(string)
return string_list

def main():

# 创建服务器套接字
serversocket = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
# 设置一个端口
port = 12345
# 将套接字与本地主机和端口绑定
serversocket.bind(("172.17.169.232",port))
# 设置监听最大连接数
serversocket.listen(5)
# 获取本地服务器的连接信息
myaddr = serversocket.getsockname()
print("服务器地址:%s"%str(myaddr))
# 循环等待接受客户端信息
while True:
    # 获取一个客户端连接
    clientsocket,addr = serversocket.accept()
    print("连接地址:%s" % str(addr))
    try:
        t = ServerThreading(clientsocket)#为每一个请求开启一个处理线程
        t.start()
        pass
    except Exception as identifier:
        print(identifier)
        pass
    pass
serversocket.close()
pass

class ServerThreading(threading.Thread):

# words = text2vec.load_lexicon()
def __init__(self,clientsocket,recvsize=1024*1024,encoding="utf-8"):
    threading.Thread.__init__(self)
    self._socket = clientsocket
    self._recvsize = recvsize
    self._encoding = encoding
    pass

def run(self):
    print("开启线程.....")
    try:
        #接受数据
        msg = ''
        while True:
            # 读取recvsize个字节
            rec = self._socket.recv(self._recvsize)
            # 解码
            msg += rec.decode(self._encoding)
            # 文本接受是否完毕,因为python socket不能自己判断接收数据是否完毕,
            # 所以需要自定义协议标志数据接受完毕
            if msg.strip().endswith('over'):
                msg=msg[:-4]
                break
        # 发送数据
        self._socket.send("啦啦啦啦".encode(self._encoding))
        pass
    except Exception as identifier:
        self._socket.send("500".encode(self._encoding))
        print(identifier)
        pass
    finally:
        self._socket.close() 
    print("任务结束.....")
    pass

//启动服务
main()

Java客户端代码如下:

public  void test2() throws IOException {
    JSONObject jsonObject = new JSONObject();
    String content = "医疗保险缴费需要";
    jsonObject.put("content", content);
    String str = jsonObject.toJSONString();
    // 访问服务进程的套接字
    Socket socket = null;

// List questions = new ArrayList<>();
// log.info("调用远程接口:host=>"+HOST+",port=>"+PORT);

    try {
        // 初始化套接字,设置访问服务的主机和进程端口号,HOST是访问python进程的主机名称,可以是IP地址或者域名,PORT是python进程绑定的端口号
        socket = new Socket("172.17.169.232",12345);
        // 获取输出流对象
        OutputStream os = socket.getOutputStream();
        PrintStream out = new PrintStream(os);
        // 发送内容
        out.print(str);
        // 告诉服务进程,内容发送完毕,可以开始处理
        out.print("over");
        // 获取服务进程的输入流
        InputStream is = socket.getInputStream();
        String text = IOUtils.toString(is);
        System.out.println(text);

    } catch (IOException e) {
        e.printStackTrace();
    } finally {
        try {if(socket!=null) socket.close();} catch (IOException e) {}
        System.out.println("远程接口调用结束.");
    }
}

socket实现Python服务端确实比较简单,但是代码量比较大,没有前面Java直接部署训练好的模型简单。

(2)使用Python的Flask框架
Flask框架实现服务端,这个框架我是听我同学说的,因为他们公司就是使用这种方法部署深度学习模型的,不过我们项目当中没有用到,有兴趣的同学可以自己去了解一下这个Flask框架,这里不累述了。

总结

常用的方法基本上就上面这些了,以上方法各有各的优缺点,大家可以根据自己的项目需求自行选择合适的方法来部署训练好的深度学习模型,希望这篇博客可以帮到你们。

作者:中二小苇
来源:CSDN
原文:https://blog.csdn.net/u012350430/article/details/96272968
版权声明:本文为博主原创文章,转载请附上博文链接!

相关文章
|
2月前
|
人工智能 自然语言处理 TensorFlow
134_边缘推理:TensorFlow Lite - 优化移动端LLM部署技术详解与实战指南
在人工智能与移动计算深度融合的今天,将大语言模型(LLM)部署到移动端和边缘设备已成为行业发展的重要趋势。TensorFlow Lite作为专为移动和嵌入式设备优化的轻量级推理框架,为开发者提供了将复杂AI模型转换为高效、低功耗边缘计算解决方案的强大工具。随着移动设备硬件性能的不断提升和模型压缩技术的快速发展,2025年的移动端LLM部署已不再是遥远的愿景,而是正在成为现实的技术实践。
|
10月前
|
机器学习/深度学习 PyTorch TensorFlow
深度学习工具和框架详细指南:PyTorch、TensorFlow、Keras
在深度学习的世界中,PyTorch、TensorFlow和Keras是最受欢迎的工具和框架,它们为研究者和开发者提供了强大且易于使用的接口。在本文中,我们将深入探索这三个框架,涵盖如何用它们实现经典深度学习模型,并通过代码实例详细讲解这些工具的使用方法。
|
机器学习/深度学习 数据采集 数据可视化
TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤
本文介绍了 TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤,包括数据准备、模型定义、损失函数与优化器选择、模型训练与评估、模型保存与部署,并展示了构建全连接神经网络的具体示例。此外,还探讨了 TensorFlow 的高级特性,如自动微分、模型可视化和分布式训练,以及其在未来的发展前景。
1031 5
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
455 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
机器学习/深度学习 人工智能 算法
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
蔬菜识别系统,本系统使用Python作为主要编程语言,通过收集了8种常见的蔬菜图像数据集('土豆', '大白菜', '大葱', '莲藕', '菠菜', '西红柿', '韭菜', '黄瓜'),然后基于TensorFlow搭建卷积神经网络算法模型,通过多轮迭代训练最后得到一个识别精度较高的模型文件。在使用Django开发web网页端操作界面,实现用户上传一张蔬菜图片识别其名称。
649 0
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
|
机器学习/深度学习 人工智能 算法
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
车辆车型识别,使用Python作为主要编程语言,通过收集多种车辆车型图像数据集,然后基于TensorFlow搭建卷积网络算法模型,并对数据集进行训练,最后得到一个识别精度较高的模型文件。再基于Django搭建web网页端操作界面,实现用户上传一张车辆图片识别其类型。
438 0
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
|
2月前
|
JSON 网络协议 安全
【Java】(10)进程与线程的关系、Tread类;讲解基本线程安全、网络编程内容;JSON序列化与反序列化
几乎所有的操作系统都支持进程的概念,进程是处于运行过程中的程序,并且具有一定的独立功能,进程是系统进行资源分配和调度的一个独立单位一般而言,进程包含如下三个特征。独立性动态性并发性。
189 1
|
2月前
|
JSON 网络协议 安全
【Java基础】(1)进程与线程的关系、Tread类;讲解基本线程安全、网络编程内容;JSON序列化与反序列化
几乎所有的操作系统都支持进程的概念,进程是处于运行过程中的程序,并且具有一定的独立功能,进程是系统进行资源分配和调度的一个独立单位一般而言,进程包含如下三个特征。独立性动态性并发性。
213 1
|
3月前
|
数据采集 存储 弹性计算
高并发Java爬虫的瓶颈分析与动态线程优化方案
高并发Java爬虫的瓶颈分析与动态线程优化方案
Java 数据库 Spring
164 0