我用python训练好后以下面代码保存模型
output_graph_def=tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,output_node_names=['output'])
with tf.gfile.FastGFile('D:/data/netGraph/yzm.pb',mode='wb') as f:
f.write(output_graph_def.SerializeToString())
“output”是输出模块
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2,name="output")
然后以下面代码加载和运行模型
//加载
byte[] graph =Files.readAllBytes(path);
Graph g = new Graph();
g.importGraphDef(graph);
sess = new Session(g);
//计算,data是一个float数组
FloatBuffer buffer = FloatBuffer.wrap(data);
Tensor<Float> input = Tensor.create(new long[] { 1, 1200 }, buffer);
Tensor keepProb = Tensor.create(new long[] { 1 }, FloatBuffer.wrap(new float[] { 1.0f }));
Tensor result = sess.runner().feed("x", input).feed("keep_prob", keepProb).fetch("output").run().get(0);
现在问题是:
同样的数据,java计算出的结果和python计算出的结果不一样(差别很大)
了解行业+人工智能最先进的技术和实践,参与行业+人工智能实践项目