TensorFlow2.0(12):模型保存与序列化

本文涉及的产品
交互式建模 PAI-DSW,5000CU*H 3个月
简介: TensorFlow2.0(12):模型保存与序列化

模型训练好之后,我们就要想办法将其持久化保存下来,不然关机或者程序退出后模型就不复存在了。本文介绍两种持久化保存模型的方法:


在介绍这两种方法之前,我们得先创建并训练好一个模型,还是以mnist手写数字识别数据集训练模型为例:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, Sequential
model = Sequential([  # 创建模型
    layers.Dense(256, activation=tf.nn.relu),
    layers.Dense(128, activation=tf.nn.relu),
    layers.Dense(64, activation=tf.nn.relu),
    layers.Dense(32, activation=tf.nn.relu),
    layers.Dense(10)
    ]
)
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop())
history = model.fit(x_train, y_train,  # 进行简单的1次迭代训练
                    batch_size=64,
                    epochs=1)


Train on 60000 samples
60000/60000 [==============================] - 3s 46us/sample - loss: 2.3700


方法一:model.save()


通过模型自带的save()方法可以将模型保存到一个指定文件中,保存的内容包括:


  • 模型的结构
  • 模型的权重参数
  • 通过compile()方法配置的模型训练参数
  • 优化器及其状态


model.save('mymodels/mnist.h5')


使用save()方法保存后,在mymodels目录下就会有一个mnist.h5文件。需要使用模型时,通过keras.models.load_model()方法从文件中再次加载即可。


new_model = keras.models.load_model('mymodels/mnist.h5')


WARNING:tensorflow:Sequential models without an `input_shape` passed to the first layer cannot reload their optimizer state. As a result, your model isstarting with a freshly initialized optimizer.


新加载出来的new_model在结构、功能、参数各方面与model是一样的。


通过save()方法,也可以将模型保存为SavedModel 格式。SavedModel格式是TensorFlow所特有的一种序列化文件格式,其他编程语言实现的TensorFlow中同样支持:


model.save('mymodels/mnist_model', save_format='tf')  # 将模型保存为SaveModel格式


WARNING:tensorflow:From /home/chb/anaconda3/envs/study_python/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1781: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: mymodels/mnist_model/assets


new_model = keras.models.load_model('mymodels/mnist_model')  # 加载模型


print(keras.models.__dir__())


['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__path__', '__file__', '__cached__', '__builtins__', '_sys', 'Sequential', 'Model', 'clone_model', 'model_from_config', 'model_from_json', 'model_from_yaml', 'load_model', 'save_model']


方法二:model.save_weights()


save()方法会保留模型的所有信息,但有时候,我们仅对部分信息感兴趣,例如仅对模型的权重参数感兴趣,那么就可以通过save_weights()方法进行保存。


model.save_weights('mymodels/mnits_weights')  # 保存模型权重信息
new_model = Sequential([  # 创建新的模型
    layers.Dense(256, activation=tf.nn.relu),
    layers.Dense(128, activation=tf.nn.relu),
    layers.Dense(64, activation=tf.nn.relu),
    layers.Dense(32, activation=tf.nn.relu),
    layers.Dense(10)
    ]
)
new_model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop())
new_model.load_weights('mymodels/mnits_weights')  # 将保存好的权重信息加载的新的模型中


<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f49c42b87d0>


注:本系列所有博客将持续更新并发布在github上,您可以通过github下载本系列所有文章笔记文件。

相关文章
|
3月前
|
机器学习/深度学习 算法 TensorFlow
文本分类识别Python+卷积神经网络算法+TensorFlow模型训练+Django可视化界面
文本分类识别Python+卷积神经网络算法+TensorFlow模型训练+Django可视化界面
63 0
文本分类识别Python+卷积神经网络算法+TensorFlow模型训练+Django可视化界面
|
3月前
|
机器学习/深度学习 监控 Python
tensorflow2.x多层感知机模型参数量和计算量的统计
tensorflow2.x多层感知机模型参数量和计算量的统计
|
6月前
|
TensorFlow 算法框架/工具
【tensorflow】TF1.x保存与读取.pb模型写法介绍
由于TF里面的概念比较接地气,所以用tf1.x保存.pb模型时总是怕有什么操作漏掉了,会造成保存的模型是缺少变量数据或者没有保存图,所以先明确一下:用TF1.x保存模型时只需要保存模型的输入输出的变量(多输入就保存多个),不需要保存中间的变量;用TF1.x加载模型时只需要加载保存的模型,然后读一下输入输出变量(多输入就读多个),不需要初始化(反而会重置掉变量的值)。
|
6月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【tensorflow】连续输入的线性回归模型训练代码
  get_data函数用于生成随机的训练和验证数据集。首先使用np.random.rand生成一个形状为(10000, 10)的随机数据集,来模拟10维的连续输入,然后使用StandardScaler对数据进行标准化。再生成一个(10000,1)的target,表示最终拟合的目标分数。最后使用train_test_split函数将数据集划分为训练集和验证集。
|
6月前
|
XML 存储 JSON
【面试题精讲】序列化协议对应于 TCP/IP 4 层模型的哪一层?
【面试题精讲】序列化协议对应于 TCP/IP 4 层模型的哪一层?
|
6月前
|
机器学习/深度学习 算法 TensorFlow
树叶识别系统python+Django网页界面+TensorFlow+算法模型+数据集+图像识别分类
树叶识别系统python+Django网页界面+TensorFlow+算法模型+数据集+图像识别分类
134 1
|
6月前
|
机器学习/深度学习 移动开发 算法
动物识别系统python+Django网页界面+TensorFlow算法模型+数据集训练
动物识别系统python+Django网页界面+TensorFlow算法模型+数据集训练
92 0
动物识别系统python+Django网页界面+TensorFlow算法模型+数据集训练
|
6月前
|
机器学习/深度学习 算法 TensorFlow
交通标志识别系统python+TensorFlow+算法模型+Django网页+数据集
交通标志识别系统python+TensorFlow+算法模型+Django网页+数据集
62 0
|
3月前
|
机器学习/深度学习 搜索推荐 算法
推荐系统离线评估方法和评估指标,以及在推荐服务器内部实现A/B测试和解决A/B测试资源紧张的方法。还介绍了如何在TensorFlow中进行模型离线评估实践。
推荐系统离线评估方法和评估指标,以及在推荐服务器内部实现A/B测试和解决A/B测试资源紧张的方法。还介绍了如何在TensorFlow中进行模型离线评估实践。
194 0
|
9天前
|
机器学习/深度学习 TensorFlow 调度
优化TensorFlow模型:超参数调整与训练技巧
【4月更文挑战第17天】本文探讨了如何优化TensorFlow模型的性能,重点介绍了超参数调整和训练技巧。超参数如学习率、批量大小和层数对模型性能至关重要。文章提到了三种超参数调整策略:网格搜索、随机搜索和贝叶斯优化。此外,还分享了训练技巧,包括学习率调度、早停、数据增强和正则化,这些都有助于防止过拟合并提高模型泛化能力。结合这些方法,可构建更高效、健壮的深度学习模型。