TensorFlow训练模型需要经过多个epoch，但并不是epoch越多越好：很可能训练到一半时，模型的效果就开始下降，这时我们需要停止训练，并及时保存模型。为了实现这种需求，我们可以自定义回调函数，自动监测模型的损失：当损失连续若干个epoch（由patience参数指定）不再下降时，就自动让模型停止训练，并将模型恢复为损失最小时的权重。
完整代码
"""
* Created with PyCharm
* 作者: 阿光
* 日期: 2022/1/4
* 时间: 10:32
* 描述:
"""
import numpy as np
import tensorflow as tf
from keras import Model
from tensorflow import keras
from tensorflow.keras.layers import *
def get_model(input_dim=784, learning_rate=0.1):
    """Build and compile a single-layer linear regression model.

    The model maps a flat ``input_dim``-element vector to one scalar through a
    single ``Dense(1)`` layer and is compiled with mean squared error loss.

    Args:
        input_dim: Length of each flattened input sample (784 matches the
            flattened MNIST images used in this script).
        learning_rate: Learning rate passed to the RMSprop optimizer.

    Returns:
        A compiled ``tensorflow.keras`` Model that also reports mean absolute
        error as a metric.
    """
    inputs = Input(shape=(input_dim,))
    outputs = Dense(1)(inputs)
    # ``keras`` here is ``tensorflow.keras`` (see ``from tensorflow import keras``).
    # Build the Model from the same package as Input/Dense: the standalone
    # ``keras`` package's Model can be a different implementation and reject
    # tf.keras tensors depending on the installed versions.
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=learning_rate),
        loss='mean_squared_error',
        metrics=['mean_absolute_error'],
    )
    return model
# Load MNIST, keep only the first 1000 samples of each split so the demo runs
# quickly, then flatten every image to a 784-element float32 vector divided by 255.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, y_train = x_train[:1000], y_train[:1000]
x_test, y_test = x_test[:1000], y_test[:1000]
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0
class CustomEarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training once a monitored loss stops decreasing.

    After every epoch the value ``logs[monitor]`` is compared with the lowest
    value seen so far. When it has failed to improve for ``patience``
    consecutive epochs, training is stopped and the model weights are rolled
    back to those captured at the best epoch.

    Args:
        patience: Number of consecutive non-improving epochs to tolerate
            before stopping. With the default of 0, training stops at the
            first epoch whose value is not strictly lower than the best so far.
        monitor: Key in the per-epoch ``logs`` dict to watch. Defaults to
            ``"loss"`` (the training loss).
    """

    def __init__(self, patience=0, monitor="loss"):
        super().__init__()
        self.patience = patience
        self.monitor = monitor
        # Weights captured at the epoch with the lowest monitored value.
        self.best_weights = None
        # Consecutive epochs without improvement.
        self.wait = 0
        # Epoch index at which training was stopped (0 = not stopped).
        self.stopped_epoch = 0
        # np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.best = np.inf

    def on_train_begin(self, logs=None):
        """Reset tracking state so one instance can be reused across fit() calls."""
        self.best_weights = None
        self.wait = 0
        self.stopped_epoch = 0
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        """Record improvements; stop and restore best weights after ``patience`` misses."""
        current = (logs or {}).get(self.monitor)
        if current is None:
            import warnings

            warnings.warn(
                f"Early stopping requires '{self.monitor}' in the epoch logs; "
                "it is not available, skipping this epoch."
            )
            return

        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            self.best_weights = self.model.get_weights()
            return

        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            self.model.stop_training = True
            # best_weights stays None only if no epoch ever improved (e.g. a NaN
            # loss from the very first epoch); there is nothing to restore then.
            if self.best_weights is not None:
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        """Report the (1-based) epoch at which early stopping triggered, if any."""
        if self.stopped_epoch > 0:
            print(f"Epoch {self.stopped_epoch + 1:05d}: early stopping")
# Train a freshly built model. The custom callback checks the training loss
# after every epoch; when it stops training it restores the best weights it saw.
model = get_model()
model.fit(
    x=x_train,
    y=y_train,
    epochs=10,
    batch_size=128,
    validation_split=0.5,
    callbacks=[CustomEarlyStoppingAtMinLoss()],
    verbose=1,
)