项目介绍
本项目采用卷积神经网络,为了保证整个项目的完整性,在训练过程中不仅要显示损失或者准确率,而且在训练完成后需要保存得到的模型,然后调用摄像头来实时预测新的图像,新图像可以是数据集中的,也可以是自己手写的。通过实现整个过程,将OpenCV、神经网络以及TensorFlow结合起来学习,项目流程图如图所示。
项目流程图
在5.4节中,通过TensorFlow框架实现了一个类似于LeNet-5的神经网络,来解决MNIST数据集上的手写数字识别问题,本节训练过程依然使用该网络,并且在最后训练出模型,模型文件以变量的形式存储参数,该变量需要在代码中初始化。在训练过程中,将更新的参数存储到变量中,使用tf.train.Saver()对象将所有的变量添加到Graph中。
保存模型的函数为:
save_path = saver.save(sess, model_path)
如果每隔一定的迭代步数就保存一次模型,就把迭代步数作为参数传进去:
save_path = saver.save(sess, model_path, global_step=step,write_meta_graph=False)
在模型保存之后,调用该模型可以完成新数据的分类预测,模型在保存后会生成4个文件,TensorFlow模型如图所示。
TensorFlow模型
其中,model.meta是训练过程中保存的元数据;model.data-00000-of-00001和model.index是检查点文件,存储着训练过程中保存的模型;checkpoint是记录文件,保存最新检查点文件的记录。