开发者社区> 问答> 正文

PAI平台怎么保存Keras模型

Keras里的model.save()方法因为底层调用了h5py的I/O方法,导致无法在PAI平台上保存模型,请问有什么解决方法吗?

展开
收起
tigerrex 2017-09-27 15:54:39 3690 0
3 条回答
写回答
取消 提交回答
  • 可以这样,在训练阶段:
    from keras import backend as K
    import tensorflow as tf

    sess = tf.Session()
    K.set_session(sess)

    然后用keras定义网络、训练网络,最后保存的时候用tensorflow的方式保存

    saver = tf.train.Saver()
    saver.save(K.get_session(), model_path)

    在预测阶段,需要读取模型的时候可以这样:
    sess = tf.Session()
    model_path = xxx
    K.set_session(sess)

    然后重新定义一下网络结构
    model = Sequential()
    model.add(xxx)

    就可开始导入模型了

    saver = tf.train.Saver()
    saver.restore(K.get_session(), model_path)

    这步做完,模型参数就被导入好定义好的网络上了,即可以开始预测

    model.predict(xxx)

    2019-07-17 21:37:05
    赞同 展开评论 打赏
  • 同问

    2019-07-17 21:37:05
    赞同 展开评论 打赏
  • print 出来,然后手动下载?。。

    2019-07-17 21:37:05
    赞同 展开评论 打赏
问答排行榜
最热
最新

相关电子书

更多
微博机器学习平台架构和实践 立即下载
机器学习及人机交互实战 立即下载
大数据与机器学习支撑的个性化大屏 立即下载