开发者社区> 问答> 正文

PAI平台怎么保存Keras模型

tigerrex 2017-09-27 15:54:39 1608

Keras里的model.save()方法因为底层调用了h5py的I/O方法,导致无法在PAI平台上保存模型,请问有什么解决方法吗?

算法框架/工具
分享到
取消 提交回答
全部回答(3)
  • yaphel
    2019-07-17 21:37:05

    可以这样,在训练阶段:
    from keras import backend as K
    import tensorflow as tf

    sess = tf.Session()
    K.set_session(sess)

    然后用keras定义网络、训练网络,最后保存的时候用tensorflow的方式保存

    saver = tf.train.Saver()
    saver.save(K.get_session(), model_path)

    在预测阶段,需要读取模型的时候可以这样:
    sess = tf.Session()
    model_path = xxx
    K.set_session(sess)

    然后重新定义一下网络结构
    model = Sequential()
    model.add(xxx)

    就可开始导入模型了

    saver = tf.train.Saver()
    saver.restore(K.get_session(), model_path)

    这步做完,模型参数就被导入好定义好的网络上了,即可以开始预测

    model.predict(xxx)

    0 0
  • eric3232332
    2019-07-17 21:37:05

    同问

    0 0
  • wayne2019
    2019-07-17 21:37:05

    print 出来,然后手动下载?。。

    0 0
添加回答
+ 订阅

了解行业+人工智能最先进的技术和实践,参与行业+人工智能实践项目

推荐文章
相似问题
推荐课程