开发者社区> 问答> 正文

java载入tensorflow模型后计算结果与python的不一样

我用python训练好后以下面代码保存模型

output_graph_def=tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,output_node_names=['output'])
    
with tf.gfile.FastGFile('D:/data/netGraph/yzm.pb',mode='wb') as f:
        f.write(output_graph_def.SerializeToString())

“output”是输出模块

y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2,name="output")

然后以下面代码加载和运行模型

//加载
byte[] graph =Files.readAllBytes(path);
Graph g = new Graph();
g.importGraphDef(graph);
sess = new Session(g);
//计算,data是一个float数组
FloatBuffer buffer = FloatBuffer.wrap(data);
Tensor<Float> input = Tensor.create(new long[] { 1, 1200 }, buffer);
Tensor keepProb = Tensor.create(new long[] { 1 }, FloatBuffer.wrap(new float[] { 1.0f }));
Tensor result = sess.runner().feed("x", input).feed("keep_prob", keepProb).fetch("output").run().get(0);

现在问题是:
同样的数据,java计算出的结果和python计算出的结果不一样(差别很大)

展开
收起
只是一条狗 2017-12-18 09:37:08 7186 0
2 条回答
写回答
取消 提交回答
  • 一个debug的思路吧……
    写一个y=x的模型看看同样的代码和保存的模型加载出来哪边的计算结果不一样

    2019-07-17 21:49:28
    赞同 展开评论 打赏
  • 个人博客www.soaringroad.com构建中,欢迎光临。

    虽然没有用过tensorflow的java api
    但是提供一下思路,希望能够对楼主有帮助
    ① 通过tensorboard对比python和java的计算图
    ② 比较python和java的weights
    ③ 可以调查一下是否和java计算的精度问题有关,比如3/2=1,python2.x是这样的,但是Python3的话结果是1.5

    2019-07-17 21:49:28
    赞同 展开评论 打赏
问答排行榜
最热
最新

相关电子书

更多
Spring Cloud Alibaba - 重新定义 Java Cloud-Native 立即下载
The Reactive Cloud Native Arch 立即下载
JAVA开发手册1.5.0 立即下载