TensorFlow自己定义EarlyStop回调函数通过监测loss指标

简介: TensorFlow训练模型需要经过多个epoch,但是并不是epoch越多越好,很有可能训练一半的epoch时,模型的效果开始下降,这是我们需要停止训练,及时的保存模型,为了完成这种需求我们可以自定义回调函数,自动检测模型的损失,只要达到一定阈值我们手动让模型停止训练

TensorFlow训练模型需要经过多个epoch,但是并不是epoch越多越好,很有可能训练一半的epoch时,模型的效果开始下降,这是我们需要停止训练,及时的保存模型,为了完成这种需求我们可以自定义回调函数,自动检测模型的损失,只要达到一定阈值我们手动让模型停止训练


完整代码


"""

* Created with PyCharm

* 作者: 阿光

* 日期: 2022/1/4

* 时间: 10:32

* 描述:

"""

import numpy as np

import tensorflow as tf

from keras import Model

from tensorflow import keras

from tensorflow.keras.layers import *



def get_model():

   inputs = Input(shape=(784,))

   outputs = Dense(1)(inputs)

   model = Model(inputs, outputs)

   model.compile(

       optimizer=keras.optimizers.RMSprop(learning_rate=0.1),

       loss='mean_squared_error',

       metrics=['mean_absolute_error']

   )

   return model



(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train = x_train.reshape(-1, 784).astype('float32') / 255.0

x_test = x_test.reshape(-1, 784).astype('float32') / 255.0


x_train = x_train[:1000]

y_train = y_train[:1000]

x_test = x_test[:1000]

y_test = y_test[:1000]



class CustomEarlyStoppingAtMinLoss(keras.callbacks.Callback):

   def __init__(self, patience=0):

       super(CustomEarlyStoppingAtMinLoss, self).__init__()

       self.patience = patience

       self.best_weights = None

       self.wait = 0

       self.stopped_epoch = 0

       self.best = np.Inf


   def on_train_begin(self, logs=None):

       pass


   def on_epoch_end(self, epoch, logs=None):

       current = logs.get("loss")

       if np.less(current, self.best):

           self.best = current

           self.wait = 0

           self.best_weights = self.model.get_weights()

       else:

           self.wait += 1

           if self.wait >= self.patience:

               self.stopped_epoch = epoch

               self.model.stop_training = True

               print("Restoring model weights from the end of the best epoch.")

               self.model.set_weights(self.best_weights)


   def on_train_end(self, logs=None):

       if self.stopped_epoch > 0:

           print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))



model = get_model()

model.fit(

   x_train,

   y_train,

   batch_size=128,

   epochs=10,

   verbose=1,

   validation_split=0.5,

   callbacks=[CustomEarlyStoppingAtMinLoss()],

)

目录
相关文章
|
5月前
|
机器学习/深度学习 TensorFlow 调度
优化TensorFlow模型:超参数调整与训练技巧
【4月更文挑战第17天】本文探讨了如何优化TensorFlow模型的性能,重点介绍了超参数调整和训练技巧。超参数如学习率、批量大小和层数对模型性能至关重要。文章提到了三种超参数调整策略:网格搜索、随机搜索和贝叶斯优化。此外,还分享了训练技巧,包括学习率调度、早停、数据增强和正则化,这些都有助于防止过拟合并提高模型泛化能力。结合这些方法,可构建更高效、健壮的深度学习模型。
|
5月前
|
数据可视化
R语言建立和可视化混合效应模型mixed effect model
R语言建立和可视化混合效应模型mixed effect model
|
5月前
|
机器学习/深度学习 数据可视化 TensorFlow
用TensorBoard可视化tensorflow神经网络模型结构与训练过程的方法
用TensorBoard可视化tensorflow神经网络模型结构与训练过程的方法
303 1
|
数据处理
超参数调整实战:scikit-learn配合XGBoost的竞赛top20策略
超参数调整实战:scikit-learn配合XGBoost的竞赛top20策略
290 1
超参数调整实战:scikit-learn配合XGBoost的竞赛top20策略
|
XML 存储 TensorFlow
Tensorflow目标检测接口配合tflite量化模型(一)
Tensorflow目标检测接口配合tflite量化模型
195 0
Tensorflow目标检测接口配合tflite量化模型(一)
|
分布式计算 并行计算 Hadoop
Tensorflow目标检测接口配合tflite量化模型(二)
Tensorflow目标检测接口配合tflite量化模型
354 0
Tensorflow目标检测接口配合tflite量化模型(二)
|
机器学习/深度学习 TensorFlow API
在tensorflow2.2中使用Keras自定义模型的指标度量
在tensorflow2.2中使用Keras自定义模型的指标度量
136 0
在tensorflow2.2中使用Keras自定义模型的指标度量
|
PyTorch TensorFlow API
对比PyTorch和TensorFlow的自动差异和动态子类化模型
对比PyTorch和TensorFlow的自动差异和动态子类化模型
165 0
对比PyTorch和TensorFlow的自动差异和动态子类化模型
|
TensorFlow 算法框架/工具
TensorFlow自己定义EarlyStop回调函数通过监测loss指标
TensorFlow自己定义EarlyStop回调函数通过监测loss指标
114 0