PAI平台怎么保存Keras模型
可以这样,在训练阶段:from keras import backend as Kimport tensorflow as tfsess = tf.Session()K.set_session(sess)然后用keras定义网络、训练网络,最后保存的时候用tensorflow的方式保存saver = tf.train.Saver()saver.save(K.get_session(), model_path)在预测阶段,需要读取模型的时候可以这样:sess = tf.Session()model_path = xxxK.set_session(sess)然后重新定义一下网络结构model = Sequential()model.add(xxx)就可开始导入模型了saver = tf.train.Saver()saver.restore(K.get_session(), model_path)这步做完,模型参数就被导入好定义好的网络上了,即可以开始预测model.predict(xxx)
赞0
踩0