使用Java部署TensorFlow和Keras训练好的深度学习模型的几种方法

简介: 使用Java部署TensorFlow和Keras训练好的深度学习模型的几种方法写在前面最近在一个自然语言处理方面的项目,选用的深度学习模型有两个,一个是CNN+LSTM模型,一个是GRU模型,这两个模型在GPU服务器上训练好了,然后需要使用Java调用这两个模型,CNN+LSTM使用TensorFlow写的,GRU是用Keras写的,所以需要用Java部署TensorFlow和Keras训练好的深度学习模型。

使用Java部署TensorFlow和Keras训练好的深度学习模型的几种方法
写在前面
最近在一个自然语言处理方面的项目,选用的深度学习模型有两个,一个是CNN+LSTM模型,一个是GRU模型,这两个模型在GPU服务器上训练好了,然后需要使用Java调用这两个模型,CNN+LSTM使用TensorFlow写的,GRU是用Keras写的,所以需要用Java部署TensorFlow和Keras训练好的深度学习模型。关于这方面的内容网上并不是很多,我也是费了很多周折才完成任务的,这里来总结一下具体有哪些方法可以用,这些方法又有哪些缺陷,以供大家学习交流。

一、使用Java深度学习框架直接部署
(1)使用TensorFlow Java API部署TensorFlow模型
如果我们使用的是TensorFlow训练的模型,那么我们就可以直接使用Java中的TensorFlow API调用模型。这里需要注意的是我们得把训练好的模型保存为.pb格式的文件。具体代码如下:

constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=["quest_out"])

写入序列化的 PB 文件

with tf.gfile.FastGFile('/home/amax/zth/qa/new_model2_cpu.pb', mode='wb') as f:

f.write(constant_graph.SerializeToString())

然后我们需要在Java使用这个保存好的模型,在pom.xml中引入TensorFlow的依赖

        <groupId>org.tensorflow</groupId>
        <artifactId>tensorflow</artifactId>
        <version>1.11.0</version>
    </dependency>

    <!-- https://mvnrepository.com/artifact/org.tensorflow/libtensorflow_jni_gpu -->
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>libtensorflow_jni_gpu</artifactId>
        <version>1.11.0</version>
    </dependency>

导包成功后,在Java中调用模型

graphDef = readAllBytes(new FileInputStream(new_model2_cpu.pb));
graph = new Graph();
graph.importGraphDef(graphDef);
session = new Session(graph);
Tensor result = session.runner()

            .feed("ori_quest_embedding", Tensor.create(wordVecInputSentence))//输入你自己的数据
            .feed("dropout", Tensor.create(1.0F))
            .fetch("quest_out") //和上面python保存模型时的output_node_names对应
            .run().get(0);

//这样就能得到模型的输出结果了

(2)使用Deeplearning4J Java API部署Keras模型
如果我们使用的是Keras训练的模型,那么你就可以选择Deeplearning4J 这个框架来调用模型。
第一步同样是使用Keras保存训练好的模型

filepath = "query_models"

checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True)

callback_lists = [checkpoint]

model.fit(x, y, epochs=1,validation_split=0.2,callbacks=callback_lists)

然后同样是Java项目中pom.xml导入Deeplearning4J 依赖

<groupId>org.deeplearning4j</groupId>      
<artifactId>deeplearning4j-core</artifactId>
<version>1.0.0-beta2</version>    


<groupId>org.deeplearning4j</groupId>      
<artifactId>deeplearning4j-modelimport</artifactId>      
<version>1.0.0-beta2</version>    

库导入成功后,直接使用Java调用保存好的模型

MultiLayerNetwork model = KerasModelImport.importKerasSequentialModelAndWeights(“query_models”);

这样模型就部署成功了,然后关于怎么使用模型这里就不多说了。
注意:这里需要注意的是Deeplearning4J 只支持部分深度学习模型,有些模型是不支持的,譬如我这里使用的GRU模型就不支持,运行上面代码会出现以下错误。明确指明不支持GRU模型

去Deeplearning4J 官网查询发现确实现在不支持GRU模型,以下是官网截图

所以如果你想使用Deeplearning4J 来部署训练好的模型,请先查看下是否支持你所使用的模型。

二、使用Python编写服务端
(1)使用socket实现进程间的通信
用python构建服务端,然后通过Java向服务端发送请求调用模型,第一种是使用socket实现进程中的通信,代码如下:

import socket
import sys
import threading
import json
import numpy as np
import jieba
import os
import numpy as np
import nltk
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM,GRU,TimeDistributed
from keras.callbacks import ModelCheckpoint
from keras.utils import np_utils
from gensim.models.word2vec import Word2Vec
from keras.optimizers import Adam
from keras.models import load_model
import pickle

nn=network.getNetWork()

cnn = conv.main(False)

深度学习训练的神经网络,使用TensorFlow训练的神经网络模型,保存在文件中

w2v_model = Word2Vec.load("word2vec.w2v").wv
UNK = pickle.load(open('unk.pkl','rb'))
model = load_model('query_models')
a = np.zeros((1, 223,200))
model.predict(a)

def test_init(string):

cut_list = jieba.lcut(string)
x_test = np.zeros((1, 223,200))
for i in range(223):
    x_test[0,i,:] = UNK
for i in range(len(cut_list)):
    if cut_list[i] in w2v_model:
        x_test[0,i,:] = w2v_model.wv[cut_list[i]]   
return x_test,len(cut_list)

string_list = list()
def query_complet(string):

x_test,length = test_init(string)
y = model.predict(x_test)
if length>8:
    return
word1 = w2v_model.wv.most_similar(positive=[y[0][length-1]], topn=2)[0][0]
word2 = w2v_model.wv.most_similar(positive=[y[0][length-1]], topn=2)[1][0]
if word1 == '?' or word1 == '?':
    string_list.append(string)
else:
    new_str = string+word1
    query_complet(new_str)

if word2 == '?' or word2 == '?':
    string_list.append(string)
else:
    new_str = string+word2
    query_complet(new_str)
    

def new_query_complet(string):

query_complet(string)
return string_list

def main():

# 创建服务器套接字
serversocket = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
# 设置一个端口
port = 12345
# 将套接字与本地主机和端口绑定
serversocket.bind(("172.17.169.232",port))
# 设置监听最大连接数
serversocket.listen(5)
# 获取本地服务器的连接信息
myaddr = serversocket.getsockname()
print("服务器地址:%s"%str(myaddr))
# 循环等待接受客户端信息
while True:
    # 获取一个客户端连接
    clientsocket,addr = serversocket.accept()
    print("连接地址:%s" % str(addr))
    try:
        t = ServerThreading(clientsocket)#为每一个请求开启一个处理线程
        t.start()
        pass
    except Exception as identifier:
        print(identifier)
        pass
    pass
serversocket.close()
pass

class ServerThreading(threading.Thread):

# words = text2vec.load_lexicon()
def __init__(self,clientsocket,recvsize=1024*1024,encoding="utf-8"):
    threading.Thread.__init__(self)
    self._socket = clientsocket
    self._recvsize = recvsize
    self._encoding = encoding
    pass

def run(self):
    print("开启线程.....")
    try:
        #接受数据
        msg = ''
        while True:
            # 读取recvsize个字节
            rec = self._socket.recv(self._recvsize)
            # 解码
            msg += rec.decode(self._encoding)
            # 文本接受是否完毕,因为python socket不能自己判断接收数据是否完毕,
            # 所以需要自定义协议标志数据接受完毕
            if msg.strip().endswith('over'):
                msg=msg[:-4]
                break
        # 发送数据
        self._socket.send("啦啦啦啦".encode(self._encoding))
        pass
    except Exception as identifier:
        self._socket.send("500".encode(self._encoding))
        print(identifier)
        pass
    finally:
        self._socket.close() 
    print("任务结束.....")
    pass

//启动服务
main()

Java客户端代码如下:

public  void test2() throws IOException {
    JSONObject jsonObject = new JSONObject();
    String content = "医疗保险缴费需要";
    jsonObject.put("content", content);
    String str = jsonObject.toJSONString();
    // 访问服务进程的套接字
    Socket socket = null;

// List questions = new ArrayList<>();
// log.info("调用远程接口:host=>"+HOST+",port=>"+PORT);

    try {
        // 初始化套接字,设置访问服务的主机和进程端口号,HOST是访问python进程的主机名称,可以是IP地址或者域名,PORT是python进程绑定的端口号
        socket = new Socket("172.17.169.232",12345);
        // 获取输出流对象
        OutputStream os = socket.getOutputStream();
        PrintStream out = new PrintStream(os);
        // 发送内容
        out.print(str);
        // 告诉服务进程,内容发送完毕,可以开始处理
        out.print("over");
        // 获取服务进程的输入流
        InputStream is = socket.getInputStream();
        String text = IOUtils.toString(is);
        System.out.println(text);

    } catch (IOException e) {
        e.printStackTrace();
    } finally {
        try {if(socket!=null) socket.close();} catch (IOException e) {}
        System.out.println("远程接口调用结束.");
    }
}

socket实现Python服务端确实比较简单,但是代码量比较大,没有前面Java直接部署训练好的模型简单。

(2)使用Python的Flask框架
Flask框架实现服务端,这个框架我是听我同学说的,因为他们公司就是使用这种方法部署深度学习模型的,不过我们项目当中没有用到,有兴趣的同学可以自己去了解一下这个Flask框架,这里不累述了。

总结

常用的方法基本上就上面这些了,以上方法各有各的优缺点,大家可以根据自己的项目需求自行选择合适的方法来部署训练好的深度学习模型,希望这篇博客可以帮到你们。

作者:中二小苇
来源:CSDN
原文:https://blog.csdn.net/u012350430/article/details/96272968
版权声明:本文为博主原创文章,转载请附上博文链接!

相关文章
|
10天前
|
机器学习/深度学习 数据可视化 TensorFlow
使用Python实现深度学习模型的分布式训练
使用Python实现深度学习模型的分布式训练
127 73
|
21天前
|
机器学习/深度学习 人工智能 算法
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
宠物识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了37种常见的猫狗宠物种类数据集【'阿比西尼亚猫(Abyssinian)', '孟加拉猫(Bengal)', '暹罗猫(Birman)', '孟买猫(Bombay)', '英国短毛猫(British Shorthair)', '埃及猫(Egyptian Mau)', '缅因猫(Maine Coon)', '波斯猫(Persian)', '布偶猫(Ragdoll)', '俄罗斯蓝猫(Russian Blue)', '暹罗猫(Siamese)', '斯芬克斯猫(Sphynx)', '美国斗牛犬
112 29
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
|
29天前
|
机器学习/深度学习 自然语言处理 语音技术
Python在深度学习领域的应用,重点讲解了神经网络的基础概念、基本结构、训练过程及优化技巧
本文介绍了Python在深度学习领域的应用,重点讲解了神经网络的基础概念、基本结构、训练过程及优化技巧,并通过TensorFlow和PyTorch等库展示了实现神经网络的具体示例,涵盖图像识别、语音识别等多个应用场景。
52 8
|
29天前
|
机器学习/深度学习 数据采集 数据可视化
TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤
本文介绍了 TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤,包括数据准备、模型定义、损失函数与优化器选择、模型训练与评估、模型保存与部署,并展示了构建全连接神经网络的具体示例。此外,还探讨了 TensorFlow 的高级特性,如自动微分、模型可视化和分布式训练,以及其在未来的发展前景。
70 5
|
1月前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【垃圾识别系统】实现~TensorFlow+人工智能+算法网络
垃圾识别分类系统。本系统采用Python作为主要编程语言,通过收集了5种常见的垃圾数据集('塑料', '玻璃', '纸张', '纸板', '金属'),然后基于TensorFlow搭建卷积神经网络算法模型,通过对图像数据集进行多轮迭代训练,最后得到一个识别精度较高的模型文件。然后使用Django搭建Web网页端可视化操作界面,实现用户在网页端上传一张垃圾图片识别其名称。
84 0
基于Python深度学习的【垃圾识别系统】实现~TensorFlow+人工智能+算法网络
|
1月前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
93 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
1月前
|
机器学习/深度学习 人工智能 TensorFlow
基于TensorFlow的深度学习模型训练与优化实战
基于TensorFlow的深度学习模型训练与优化实战
93 0
|
18天前
|
机器学习/深度学习 传感器 数据采集
深度学习在故障检测中的应用:从理论到实践
深度学习在故障检测中的应用:从理论到实践
79 5
|
10天前
|
机器学习/深度学习 网络架构 计算机视觉
深度学习在图像识别中的应用与挑战
【10月更文挑战第21天】 本文探讨了深度学习技术在图像识别领域的应用,并分析了当前面临的主要挑战。通过研究卷积神经网络(CNN)的结构和原理,本文展示了深度学习如何提高图像识别的准确性和效率。同时,本文也讨论了数据不平衡、过拟合、计算资源限制等问题,并提出了相应的解决策略。
58 19
|
10天前
|
机器学习/深度学习 传感器 人工智能
探索深度学习在图像识别中的应用与挑战
【10月更文挑战第21天】 本文深入探讨了深度学习技术在图像识别领域的应用,并分析了当前面临的主要挑战。通过介绍卷积神经网络(CNN)的基本原理和架构设计,阐述了深度学习如何有效地从图像数据中提取特征,并在多个领域实现突破性进展。同时,文章也指出了训练深度模型时常见的过拟合问题、数据不平衡以及计算资源需求高等挑战,并提出了相应的解决策略。
54 7