我用python训练好后以下面代码保存模型
output_graph_def=tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,output_node_names=['output'])
with tf.gfile.FastGFile('D:/data/netGraph/yzm.pb',mode='wb') as f:
f.write(output_graph_def.SerializeToString())
“output”是输出模块
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2,name="output")
然后以下面代码加载和运行模型
//加载
byte[] graph =Files.readAllBytes(path);
Graph g = new Graph();
g.importGraphDef(graph);
sess = new Session(g);
//计算,data是一个float数组
FloatBuffer buffer = FloatBuffer.wrap(data);
Tensor<Float> input = Tensor.create(new long[] { 1, 1200 }, buffer);
Tensor keepProb = Tensor.create(new long[] { 1 }, FloatBuffer.wrap(new float[] { 1.0f }));
Tensor result = sess.runner().feed("x", input).feed("keep_prob", keepProb).fetch("output").run().get(0);
现在问题是:
同样的数据,java计算出的结果和python计算出的结果不一样(差别很大)
虽然没有用过tensorflow的java api
但是提供一下思路,希望能够对楼主有帮助
① 通过tensorboard对比python和java的计算图
② 比较python和java的weights
③ 可以调查一下是否和java计算的精度问题有关,比如3/2=1,python2.x是这样的,但是Python3的话结果是1.5
版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。