使用Java部署TensorFlow和Keras训练好的深度学习模型的几种方法

简介: 使用Java部署TensorFlow和Keras训练好的深度学习模型的几种方法写在前面最近在一个自然语言处理方面的项目,选用的深度学习模型有两个,一个是CNN+LSTM模型,一个是GRU模型,这两个模型在GPU服务器上训练好了,然后需要使用Java调用这两个模型,CNN+LSTM使用TensorFlow写的,GRU是用Keras写的,所以需要用Java部署TensorFlow和Keras训练好的深度学习模型。

使用Java部署TensorFlow和Keras训练好的深度学习模型的几种方法
写在前面
最近在一个自然语言处理方面的项目,选用的深度学习模型有两个,一个是CNN+LSTM模型,一个是GRU模型,这两个模型在GPU服务器上训练好了,然后需要使用Java调用这两个模型,CNN+LSTM使用TensorFlow写的,GRU是用Keras写的,所以需要用Java部署TensorFlow和Keras训练好的深度学习模型。关于这方面的内容网上并不是很多,我也是费了很多周折才完成任务的,这里来总结一下具体有哪些方法可以用,这些方法又有哪些缺陷,以供大家学习交流。

一、使用Java深度学习框架直接部署
(1)使用TensorFlow Java API部署TensorFlow模型
如果我们使用的是TensorFlow训练的模型,那么我们就可以直接使用Java中的TensorFlow API调用模型。这里需要注意的是我们得把训练好的模型保存为.pb格式的文件。具体代码如下:

constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=["quest_out"])

写入序列化的 PB 文件

with tf.gfile.FastGFile('/home/amax/zth/qa/new_model2_cpu.pb', mode='wb') as f:

f.write(constant_graph.SerializeToString())

然后我们需要在Java使用这个保存好的模型,在pom.xml中引入TensorFlow的依赖

        <groupId>org.tensorflow</groupId>
        <artifactId>tensorflow</artifactId>
        <version>1.11.0</version>
    </dependency>

    <!-- https://mvnrepository.com/artifact/org.tensorflow/libtensorflow_jni_gpu -->
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>libtensorflow_jni_gpu</artifactId>
        <version>1.11.0</version>
    </dependency>

导包成功后,在Java中调用模型

graphDef = readAllBytes(new FileInputStream(new_model2_cpu.pb));
graph = new Graph();
graph.importGraphDef(graphDef);
session = new Session(graph);
Tensor result = session.runner()

            .feed("ori_quest_embedding", Tensor.create(wordVecInputSentence))//输入你自己的数据
            .feed("dropout", Tensor.create(1.0F))
            .fetch("quest_out") //和上面python保存模型时的output_node_names对应
            .run().get(0);

//这样就能得到模型的输出结果了

(2)使用Deeplearning4J Java API部署Keras模型
如果我们使用的是Keras训练的模型,那么你就可以选择Deeplearning4J 这个框架来调用模型。
第一步同样是使用Keras保存训练好的模型

filepath = "query_models"

checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True)

callback_lists = [checkpoint]

model.fit(x, y, epochs=1,validation_split=0.2,callbacks=callback_lists)

然后同样是Java项目中pom.xml导入Deeplearning4J 依赖

<groupId>org.deeplearning4j</groupId>      
<artifactId>deeplearning4j-core</artifactId>
<version>1.0.0-beta2</version>    


<groupId>org.deeplearning4j</groupId>      
<artifactId>deeplearning4j-modelimport</artifactId>      
<version>1.0.0-beta2</version>    

库导入成功后,直接使用Java调用保存好的模型

MultiLayerNetwork model = KerasModelImport.importKerasSequentialModelAndWeights(“query_models”);

这样模型就部署成功了,然后关于怎么使用模型这里就不多说了。
注意:这里需要注意的是Deeplearning4J 只支持部分深度学习模型,有些模型是不支持的,譬如我这里使用的GRU模型就不支持,运行上面代码会出现以下错误。明确指明不支持GRU模型

去Deeplearning4J 官网查询发现确实现在不支持GRU模型,以下是官网截图

所以如果你想使用Deeplearning4J 来部署训练好的模型,请先查看下是否支持你所使用的模型。

二、使用Python编写服务端
(1)使用socket实现进程间的通信
用python构建服务端,然后通过Java向服务端发送请求调用模型,第一种是使用socket实现进程中的通信,代码如下:

import socket
import sys
import threading
import json
import numpy as np
import jieba
import os
import numpy as np
import nltk
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM,GRU,TimeDistributed
from keras.callbacks import ModelCheckpoint
from keras.utils import np_utils
from gensim.models.word2vec import Word2Vec
from keras.optimizers import Adam
from keras.models import load_model
import pickle

nn=network.getNetWork()

cnn = conv.main(False)

深度学习训练的神经网络,使用TensorFlow训练的神经网络模型,保存在文件中

w2v_model = Word2Vec.load("word2vec.w2v").wv
UNK = pickle.load(open('unk.pkl','rb'))
model = load_model('query_models')
a = np.zeros((1, 223,200))
model.predict(a)

def test_init(string):

cut_list = jieba.lcut(string)
x_test = np.zeros((1, 223,200))
for i in range(223):
    x_test[0,i,:] = UNK
for i in range(len(cut_list)):
    if cut_list[i] in w2v_model:
        x_test[0,i,:] = w2v_model.wv[cut_list[i]]   
return x_test,len(cut_list)

string_list = list()
def query_complet(string):

x_test,length = test_init(string)
y = model.predict(x_test)
if length>8:
    return
word1 = w2v_model.wv.most_similar(positive=[y[0][length-1]], topn=2)[0][0]
word2 = w2v_model.wv.most_similar(positive=[y[0][length-1]], topn=2)[1][0]
if word1 == '?' or word1 == '?':
    string_list.append(string)
else:
    new_str = string+word1
    query_complet(new_str)

if word2 == '?' or word2 == '?':
    string_list.append(string)
else:
    new_str = string+word2
    query_complet(new_str)
    

def new_query_complet(string):

query_complet(string)
return string_list

def main():

# 创建服务器套接字
serversocket = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
# 设置一个端口
port = 12345
# 将套接字与本地主机和端口绑定
serversocket.bind(("172.17.169.232",port))
# 设置监听最大连接数
serversocket.listen(5)
# 获取本地服务器的连接信息
myaddr = serversocket.getsockname()
print("服务器地址:%s"%str(myaddr))
# 循环等待接受客户端信息
while True:
    # 获取一个客户端连接
    clientsocket,addr = serversocket.accept()
    print("连接地址:%s" % str(addr))
    try:
        t = ServerThreading(clientsocket)#为每一个请求开启一个处理线程
        t.start()
        pass
    except Exception as identifier:
        print(identifier)
        pass
    pass
serversocket.close()
pass

class ServerThreading(threading.Thread):

# words = text2vec.load_lexicon()
def __init__(self,clientsocket,recvsize=1024*1024,encoding="utf-8"):
    threading.Thread.__init__(self)
    self._socket = clientsocket
    self._recvsize = recvsize
    self._encoding = encoding
    pass

def run(self):
    print("开启线程.....")
    try:
        #接受数据
        msg = ''
        while True:
            # 读取recvsize个字节
            rec = self._socket.recv(self._recvsize)
            # 解码
            msg += rec.decode(self._encoding)
            # 文本接受是否完毕,因为python socket不能自己判断接收数据是否完毕,
            # 所以需要自定义协议标志数据接受完毕
            if msg.strip().endswith('over'):
                msg=msg[:-4]
                break
        # 发送数据
        self._socket.send("啦啦啦啦".encode(self._encoding))
        pass
    except Exception as identifier:
        self._socket.send("500".encode(self._encoding))
        print(identifier)
        pass
    finally:
        self._socket.close() 
    print("任务结束.....")
    pass

//启动服务
main()

Java客户端代码如下:

public  void test2() throws IOException {
    JSONObject jsonObject = new JSONObject();
    String content = "医疗保险缴费需要";
    jsonObject.put("content", content);
    String str = jsonObject.toJSONString();
    // 访问服务进程的套接字
    Socket socket = null;

// List questions = new ArrayList<>();
// log.info("调用远程接口:host=>"+HOST+",port=>"+PORT);

    try {
        // 初始化套接字,设置访问服务的主机和进程端口号,HOST是访问python进程的主机名称,可以是IP地址或者域名,PORT是python进程绑定的端口号
        socket = new Socket("172.17.169.232",12345);
        // 获取输出流对象
        OutputStream os = socket.getOutputStream();
        PrintStream out = new PrintStream(os);
        // 发送内容
        out.print(str);
        // 告诉服务进程,内容发送完毕,可以开始处理
        out.print("over");
        // 获取服务进程的输入流
        InputStream is = socket.getInputStream();
        String text = IOUtils.toString(is);
        System.out.println(text);

    } catch (IOException e) {
        e.printStackTrace();
    } finally {
        try {if(socket!=null) socket.close();} catch (IOException e) {}
        System.out.println("远程接口调用结束.");
    }
}

socket实现Python服务端确实比较简单,但是代码量比较大,没有前面Java直接部署训练好的模型简单。

(2)使用Python的Flask框架
Flask框架实现服务端,这个框架我是听我同学说的,因为他们公司就是使用这种方法部署深度学习模型的,不过我们项目当中没有用到,有兴趣的同学可以自己去了解一下这个Flask框架,这里不累述了。

总结

常用的方法基本上就上面这些了,以上方法各有各的优缺点,大家可以根据自己的项目需求自行选择合适的方法来部署训练好的深度学习模型,希望这篇博客可以帮到你们。

作者:中二小苇
来源:CSDN
原文:https://blog.csdn.net/u012350430/article/details/96272968
版权声明:本文为博主原创文章,转载请附上博文链接!

相关文章
|
15天前
|
机器学习/深度学习 人工智能 算法
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
鸟类识别系统。本系统采用Python作为主要开发语言,通过使用加利福利亚大学开源的200种鸟类图像作为数据集。使用TensorFlow搭建ResNet50卷积神经网络算法模型,然后进行模型的迭代训练,得到一个识别精度较高的模型,然后在保存为本地的H5格式文件。在使用Django开发Web网页端操作界面,实现用户上传一张鸟类图像,识别其名称。
60 12
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
|
2月前
|
API UED 开发者
如何在Uno Platform中轻松实现流畅动画效果——从基础到优化,全方位打造用户友好的动态交互体验!
【8月更文挑战第31天】在开发跨平台应用时,确保用户界面流畅且具吸引力至关重要。Uno Platform 作为多端统一的开发框架,不仅支持跨系统应用开发,还能通过优化实现流畅动画,增强用户体验。本文探讨了Uno Platform中实现流畅动画的多个方面,包括动画基础、性能优化、实践技巧及问题排查,帮助开发者掌握具体优化策略,提升应用质量与用户满意度。通过合理利用故事板、减少布局复杂性、使用硬件加速等技术,结合异步方法与预设缓存技巧,开发者能够创建美观且流畅的动画效果。
57 0
|
2月前
|
C# 开发者 前端开发
揭秘混合开发新趋势:Uno Platform携手Blazor,教你一步到位实现跨平台应用,代码复用不再是梦!
【8月更文挑战第31天】随着前端技术的发展,混合开发日益受到开发者青睐。本文详述了如何结合.NET生态下的两大框架——Uno Platform与Blazor,进行高效混合开发。Uno Platform基于WebAssembly和WebGL技术,支持跨平台应用构建;Blazor则让C#成为可能的前端开发语言,实现了客户端与服务器端逻辑共享。二者结合不仅提升了代码复用率与跨平台能力,还简化了项目维护并增强了Web应用性能。文中提供了从环境搭建到示例代码的具体步骤,并展示了如何创建一个简单的计数器应用,帮助读者快速上手混合开发。
45 0
|
2月前
|
开发者 算法 虚拟化
惊爆!Uno Platform 调试与性能分析终极攻略,从工具运用到代码优化,带你攻克开发难题成就完美应用
【8月更文挑战第31天】在 Uno Platform 中,调试可通过 Visual Studio 设置断点和逐步执行代码实现,同时浏览器开发者工具有助于 Web 版本调试。性能分析则利用 Visual Studio 的性能分析器检查 CPU 和内存使用情况,还可通过记录时间戳进行简单分析。优化性能涉及代码逻辑优化、资源管理和用户界面简化,综合利用平台提供的工具和技术,确保应用高效稳定运行。
40 0
|
2月前
|
前端开发 开发者 设计模式
揭秘Uno Platform状态管理之道:INotifyPropertyChanged、依赖注入、MVVM大对决,帮你找到最佳策略!
【8月更文挑战第31天】本文对比分析了 Uno Platform 中的关键状态管理策略,包括内置的 INotifyPropertyChanged、依赖注入及 MVVM 框架。INotifyPropertyChanged 方案简单易用,适合小型项目;依赖注入则更灵活,支持状态共享与持久化,适用于复杂场景;MVVM 框架通过分离视图、视图模型和模型,使状态管理更清晰,适合大型项目。开发者可根据项目需求和技术栈选择合适的状态管理方案,以实现高效管理。
32 0
|
4月前
|
机器学习/深度学习 人工智能 算法
海洋生物识别系统+图像识别+Python+人工智能课设+深度学习+卷积神经网络算法+TensorFlow
海洋生物识别系统。以Python作为主要编程语言,通过TensorFlow搭建ResNet50卷积神经网络算法,通过对22种常见的海洋生物('蛤蜊', '珊瑚', '螃蟹', '海豚', '鳗鱼', '水母', '龙虾', '海蛞蝓', '章鱼', '水獭', '企鹅', '河豚', '魔鬼鱼', '海胆', '海马', '海豹', '鲨鱼', '虾', '鱿鱼', '海星', '海龟', '鲸鱼')数据集进行训练,得到一个识别精度较高的模型文件,然后使用Django开发一个Web网页平台操作界面,实现用户上传一张海洋生物图片识别其名称。
163 7
海洋生物识别系统+图像识别+Python+人工智能课设+深度学习+卷积神经网络算法+TensorFlow
|
4月前
|
机器学习/深度学习 人工智能 算法
【乐器识别系统】图像识别+人工智能+深度学习+Python+TensorFlow+卷积神经网络+模型训练
乐器识别系统。使用Python为主要编程语言,基于人工智能框架库TensorFlow搭建ResNet50卷积神经网络算法,通过对30种乐器('迪吉里杜管', '铃鼓', '木琴', '手风琴', '阿尔卑斯号角', '风笛', '班卓琴', '邦戈鼓', '卡萨巴', '响板', '单簧管', '古钢琴', '手风琴(六角形)', '鼓', '扬琴', '长笛', '刮瓜', '吉他', '口琴', '竖琴', '沙槌', '陶笛', '钢琴', '萨克斯管', '锡塔尔琴', '钢鼓', '长号', '小号', '大号', '小提琴')的图像数据集进行训练,得到一个训练精度较高的模型,并将其
56 0
【乐器识别系统】图像识别+人工智能+深度学习+Python+TensorFlow+卷积神经网络+模型训练
|
16天前
|
机器学习/深度学习 数据挖掘 TensorFlow
解锁Python数据分析新技能,TensorFlow&PyTorch双引擎驱动深度学习实战盛宴
在数据驱动时代,Python凭借简洁的语法和强大的库支持,成为数据分析与机器学习的首选语言。Pandas和NumPy是Python数据分析的基础,前者提供高效的数据处理工具,后者则支持科学计算。TensorFlow与PyTorch作为深度学习领域的两大框架,助力数据科学家构建复杂神经网络,挖掘数据深层价值。通过Python打下的坚实基础,结合TensorFlow和PyTorch的强大功能,我们能在数据科学领域探索无限可能,解决复杂问题并推动科研进步。
38 0
|
25天前
|
机器学习/深度学习 数据挖掘 TensorFlow
从数据小白到AI专家:Python数据分析与TensorFlow/PyTorch深度学习的蜕变之路
【9月更文挑战第10天】从数据新手成长为AI专家,需先掌握Python基础语法,并学会使用NumPy和Pandas进行数据分析。接着,通过Matplotlib和Seaborn实现数据可视化,最后利用TensorFlow或PyTorch探索深度学习。这一过程涉及从数据清洗、可视化到构建神经网络的多个步骤,每一步都需不断实践与学习。借助Python的强大功能及各类库的支持,你能逐步解锁数据的深层价值。
45 0
|
2月前
|
持续交付 测试技术 jenkins
JSF 邂逅持续集成,紧跟技术热点潮流,开启高效开发之旅,引发开发者强烈情感共鸣
【8月更文挑战第31天】在快速发展的软件开发领域,JavaServer Faces(JSF)这一强大的Java Web应用框架与持续集成(CI)结合,可显著提升开发效率及软件质量。持续集成通过频繁的代码集成及自动化构建测试,实现快速反馈、高质量代码、加强团队协作及简化部署流程。以Jenkins为例,配合Maven或Gradle,可轻松搭建JSF项目的CI环境,通过JUnit和Selenium编写自动化测试,确保每次构建的稳定性和正确性。
44 0

热门文章

最新文章

下一篇
无影云桌面