使用Java部署TensorFlow和Keras训练好的深度学习模型的几种方法

简介: 使用Java部署TensorFlow和Keras训练好的深度学习模型的几种方法写在前面最近在一个自然语言处理方面的项目,选用的深度学习模型有两个,一个是CNN+LSTM模型,一个是GRU模型,这两个模型在GPU服务器上训练好了,然后需要使用Java调用这两个模型,CNN+LSTM使用TensorFlow写的,GRU是用Keras写的,所以需要用Java部署TensorFlow和Keras训练好的深度学习模型。

使用Java部署TensorFlow和Keras训练好的深度学习模型的几种方法
写在前面
最近在一个自然语言处理方面的项目,选用的深度学习模型有两个,一个是CNN+LSTM模型,一个是GRU模型,这两个模型在GPU服务器上训练好了,然后需要使用Java调用这两个模型,CNN+LSTM使用TensorFlow写的,GRU是用Keras写的,所以需要用Java部署TensorFlow和Keras训练好的深度学习模型。关于这方面的内容网上并不是很多,我也是费了很多周折才完成任务的,这里来总结一下具体有哪些方法可以用,这些方法又有哪些缺陷,以供大家学习交流。

一、使用Java深度学习框架直接部署
(1)使用TensorFlow Java API部署TensorFlow模型
如果我们使用的是TensorFlow训练的模型,那么我们就可以直接使用Java中的TensorFlow API调用模型。这里需要注意的是我们得把训练好的模型保存为.pb格式的文件。具体代码如下:

constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=["quest_out"])

写入序列化的 PB 文件

with tf.gfile.FastGFile('/home/amax/zth/qa/new_model2_cpu.pb', mode='wb') as f:

f.write(constant_graph.SerializeToString())

然后我们需要在Java使用这个保存好的模型,在pom.xml中引入TensorFlow的依赖

        <groupId>org.tensorflow</groupId>
        <artifactId>tensorflow</artifactId>
        <version>1.11.0</version>
    </dependency>

    <!-- https://mvnrepository.com/artifact/org.tensorflow/libtensorflow_jni_gpu -->
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>libtensorflow_jni_gpu</artifactId>
        <version>1.11.0</version>
    </dependency>

导包成功后,在Java中调用模型

graphDef = readAllBytes(new FileInputStream(new_model2_cpu.pb));
graph = new Graph();
graph.importGraphDef(graphDef);
session = new Session(graph);
Tensor result = session.runner()

            .feed("ori_quest_embedding", Tensor.create(wordVecInputSentence))//输入你自己的数据
            .feed("dropout", Tensor.create(1.0F))
            .fetch("quest_out") //和上面python保存模型时的output_node_names对应
            .run().get(0);

//这样就能得到模型的输出结果了

(2)使用Deeplearning4J Java API部署Keras模型
如果我们使用的是Keras训练的模型,那么你就可以选择Deeplearning4J 这个框架来调用模型。
第一步同样是使用Keras保存训练好的模型

filepath = "query_models"

checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True)

callback_lists = [checkpoint]

model.fit(x, y, epochs=1,validation_split=0.2,callbacks=callback_lists)

然后同样是Java项目中pom.xml导入Deeplearning4J 依赖

<groupId>org.deeplearning4j</groupId>      
<artifactId>deeplearning4j-core</artifactId>
<version>1.0.0-beta2</version>    


<groupId>org.deeplearning4j</groupId>      
<artifactId>deeplearning4j-modelimport</artifactId>      
<version>1.0.0-beta2</version>    

库导入成功后,直接使用Java调用保存好的模型

MultiLayerNetwork model = KerasModelImport.importKerasSequentialModelAndWeights(“query_models”);

这样模型就部署成功了,然后关于怎么使用模型这里就不多说了。
注意:这里需要注意的是Deeplearning4J 只支持部分深度学习模型,有些模型是不支持的,譬如我这里使用的GRU模型就不支持,运行上面代码会出现以下错误。明确指明不支持GRU模型

去Deeplearning4J 官网查询发现确实现在不支持GRU模型,以下是官网截图

所以如果你想使用Deeplearning4J 来部署训练好的模型,请先查看下是否支持你所使用的模型。

二、使用Python编写服务端
(1)使用socket实现进程间的通信
用python构建服务端,然后通过Java向服务端发送请求调用模型,第一种是使用socket实现进程中的通信,代码如下:

import socket
import sys
import threading
import json
import numpy as np
import jieba
import os
import numpy as np
import nltk
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM,GRU,TimeDistributed
from keras.callbacks import ModelCheckpoint
from keras.utils import np_utils
from gensim.models.word2vec import Word2Vec
from keras.optimizers import Adam
from keras.models import load_model
import pickle

nn=network.getNetWork()

cnn = conv.main(False)

深度学习训练的神经网络,使用TensorFlow训练的神经网络模型,保存在文件中

w2v_model = Word2Vec.load("word2vec.w2v").wv
UNK = pickle.load(open('unk.pkl','rb'))
model = load_model('query_models')
a = np.zeros((1, 223,200))
model.predict(a)

def test_init(string):

cut_list = jieba.lcut(string)
x_test = np.zeros((1, 223,200))
for i in range(223):
    x_test[0,i,:] = UNK
for i in range(len(cut_list)):
    if cut_list[i] in w2v_model:
        x_test[0,i,:] = w2v_model.wv[cut_list[i]]   
return x_test,len(cut_list)

string_list = list()
def query_complet(string):

x_test,length = test_init(string)
y = model.predict(x_test)
if length>8:
    return
word1 = w2v_model.wv.most_similar(positive=[y[0][length-1]], topn=2)[0][0]
word2 = w2v_model.wv.most_similar(positive=[y[0][length-1]], topn=2)[1][0]
if word1 == '?' or word1 == '?':
    string_list.append(string)
else:
    new_str = string+word1
    query_complet(new_str)

if word2 == '?' or word2 == '?':
    string_list.append(string)
else:
    new_str = string+word2
    query_complet(new_str)
    

def new_query_complet(string):

query_complet(string)
return string_list

def main():

# 创建服务器套接字
serversocket = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
# 设置一个端口
port = 12345
# 将套接字与本地主机和端口绑定
serversocket.bind(("172.17.169.232",port))
# 设置监听最大连接数
serversocket.listen(5)
# 获取本地服务器的连接信息
myaddr = serversocket.getsockname()
print("服务器地址:%s"%str(myaddr))
# 循环等待接受客户端信息
while True:
    # 获取一个客户端连接
    clientsocket,addr = serversocket.accept()
    print("连接地址:%s" % str(addr))
    try:
        t = ServerThreading(clientsocket)#为每一个请求开启一个处理线程
        t.start()
        pass
    except Exception as identifier:
        print(identifier)
        pass
    pass
serversocket.close()
pass

class ServerThreading(threading.Thread):

# words = text2vec.load_lexicon()
def __init__(self,clientsocket,recvsize=1024*1024,encoding="utf-8"):
    threading.Thread.__init__(self)
    self._socket = clientsocket
    self._recvsize = recvsize
    self._encoding = encoding
    pass

def run(self):
    print("开启线程.....")
    try:
        #接受数据
        msg = ''
        while True:
            # 读取recvsize个字节
            rec = self._socket.recv(self._recvsize)
            # 解码
            msg += rec.decode(self._encoding)
            # 文本接受是否完毕,因为python socket不能自己判断接收数据是否完毕,
            # 所以需要自定义协议标志数据接受完毕
            if msg.strip().endswith('over'):
                msg=msg[:-4]
                break
        # 发送数据
        self._socket.send("啦啦啦啦".encode(self._encoding))
        pass
    except Exception as identifier:
        self._socket.send("500".encode(self._encoding))
        print(identifier)
        pass
    finally:
        self._socket.close() 
    print("任务结束.....")
    pass

//启动服务
main()

Java客户端代码如下:

public  void test2() throws IOException {
    JSONObject jsonObject = new JSONObject();
    String content = "医疗保险缴费需要";
    jsonObject.put("content", content);
    String str = jsonObject.toJSONString();
    // 访问服务进程的套接字
    Socket socket = null;

// List questions = new ArrayList<>();
// log.info("调用远程接口:host=>"+HOST+",port=>"+PORT);

    try {
        // 初始化套接字,设置访问服务的主机和进程端口号,HOST是访问python进程的主机名称,可以是IP地址或者域名,PORT是python进程绑定的端口号
        socket = new Socket("172.17.169.232",12345);
        // 获取输出流对象
        OutputStream os = socket.getOutputStream();
        PrintStream out = new PrintStream(os);
        // 发送内容
        out.print(str);
        // 告诉服务进程,内容发送完毕,可以开始处理
        out.print("over");
        // 获取服务进程的输入流
        InputStream is = socket.getInputStream();
        String text = IOUtils.toString(is);
        System.out.println(text);

    } catch (IOException e) {
        e.printStackTrace();
    } finally {
        try {if(socket!=null) socket.close();} catch (IOException e) {}
        System.out.println("远程接口调用结束.");
    }
}

socket实现Python服务端确实比较简单,但是代码量比较大,没有前面Java直接部署训练好的模型简单。

(2)使用Python的Flask框架
Flask框架实现服务端,这个框架我是听我同学说的,因为他们公司就是使用这种方法部署深度学习模型的,不过我们项目当中没有用到,有兴趣的同学可以自己去了解一下这个Flask框架,这里不累述了。

总结

常用的方法基本上就上面这些了,以上方法各有各的优缺点,大家可以根据自己的项目需求自行选择合适的方法来部署训练好的深度学习模型,希望这篇博客可以帮到你们。

作者:中二小苇
来源:CSDN
原文:https://blog.csdn.net/u012350430/article/details/96272968
版权声明:本文为博主原创文章,转载请附上博文链接!

相关文章
|
1月前
|
Java
Java语言实现字母大小写转换的方法
Java提供了多种灵活的方法来处理字符串中的字母大小写转换。根据具体需求,可以选择适合的方法来实现。在大多数情况下,使用 String类或 Character类的方法已经足够。但是,在需要更复杂的逻辑或处理非常规字符集时,可以通过字符流或手动遍历字符串来实现更精细的控制。
229 18
|
1月前
|
Java 编译器 Go
【Java】(5)方法的概念、方法的调用、方法重载、构造方法的创建
Java方法是语句的集合,它们在一起执行一个功能。方法是解决一类问题的步骤的有序组合方法包含于类或对象中方法在程序中被创建,在其他地方被引用方法的优点使程序变得更简短而清晰。有利于程序维护。可以提高程序开发的效率。提高了代码的重用性。方法的名字的第一个单词应以小写字母作为开头,后面的单词则用大写字母开头写,不使用连接符。例如:addPerson。这种就属于驼峰写法下划线可能出现在 JUnit 测试方法名称中用以分隔名称的逻辑组件。
198 4
|
1月前
|
机器学习/深度学习 人工智能 监控
Java与AI模型部署:构建企业级模型服务与生命周期管理平台
随着企业AI模型数量的快速增长,模型部署与生命周期管理成为确保AI应用稳定运行的关键。本文深入探讨如何使用Java生态构建一个企业级的模型服务平台,实现模型的版本控制、A/B测试、灰度发布、监控与回滚。通过集成Spring Boot、Kubernetes、MLflow和监控工具,我们将展示如何构建一个高可用、可扩展的模型服务架构,为大规模AI应用提供坚实的运维基础。
236 0
|
1月前
|
编解码 Java 开发者
Java String类的关键方法总结
以上总结了Java `String` 类最常见和重要功能性方法。每种操作都对应着日常编程任务,并且理解每种操作如何影响及处理 `Strings` 对于任何使用 Java 的开发者来说都至关重要。
277 5
|
1月前
|
人工智能 自然语言处理 TensorFlow
134_边缘推理:TensorFlow Lite - 优化移动端LLM部署技术详解与实战指南
在人工智能与移动计算深度融合的今天,将大语言模型(LLM)部署到移动端和边缘设备已成为行业发展的重要趋势。TensorFlow Lite作为专为移动和嵌入式设备优化的轻量级推理框架,为开发者提供了将复杂AI模型转换为高效、低功耗边缘计算解决方案的强大工具。随着移动设备硬件性能的不断提升和模型压缩技术的快速发展,2025年的移动端LLM部署已不再是遥远的愿景,而是正在成为现实的技术实践。
|
2月前
|
算法 安全 Java
除了类,Java中的接口和方法也可以使用泛型吗?
除了类,Java中的接口和方法也可以使用泛型吗?
141 11
|
2月前
|
Java 开发者
Java 函数式编程全解析:静态方法引用、实例方法引用、特定类型方法引用与构造器引用实战教程
本文介绍Java 8函数式编程中的四种方法引用:静态、实例、特定类型及构造器引用,通过简洁示例演示其用法,帮助开发者提升代码可读性与简洁性。
存储 jenkins 持续交付
493 2
|
3月前
|
算法 Java
Java语言实现链表反转的方法
这种反转方法不需要使用额外的存储空间,因此空间复杂度为,它只需要遍历一次链表,所以时间复杂度为,其中为链表的长度。这使得这种反转链表的方法既高效又实用。
384 0
|
11月前
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
949 55

热门文章

最新文章