TensorFlow自己定义EarlyStop回调函数通过监测loss指标

简介: TensorFlow训练模型需要经过多个epoch,但是并不是epoch越多越好,很有可能训练一半的epoch时,模型的效果开始下降,这是我们需要停止训练,及时的保存模型,为了完成这种需求我们可以自定义回调函数,自动检测模型的损失,只要达到一定阈值我们手动让模型停止训练

TensorFlow训练模型需要经过多个epoch,但是并不是epoch越多越好,很有可能训练一半的epoch时,模型的效果开始下降,这是我们需要停止训练,及时的保存模型,为了完成这种需求我们可以自定义回调函数,自动检测模型的损失,只要达到一定阈值我们手动让模型停止训练


完整代码


"""

* Created with PyCharm

* 作者: 阿光

* 日期: 2022/1/4

* 时间: 10:32

* 描述:

"""

import numpy as np

import tensorflow as tf

from keras import Model

from tensorflow import keras

from tensorflow.keras.layers import *



def get_model():

   inputs = Input(shape=(784,))

   outputs = Dense(1)(inputs)

   model = Model(inputs, outputs)

   model.compile(

       optimizer=keras.optimizers.RMSprop(learning_rate=0.1),

       loss='mean_squared_error',

       metrics=['mean_absolute_error']

   )

   return model



(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train = x_train.reshape(-1, 784).astype('float32') / 255.0

x_test = x_test.reshape(-1, 784).astype('float32') / 255.0


x_train = x_train[:1000]

y_train = y_train[:1000]

x_test = x_test[:1000]

y_test = y_test[:1000]



class CustomEarlyStoppingAtMinLoss(keras.callbacks.Callback):

   def __init__(self, patience=0):

       super(CustomEarlyStoppingAtMinLoss, self).__init__()

       self.patience = patience

       self.best_weights = None

       self.wait = 0

       self.stopped_epoch = 0

       self.best = np.Inf


   def on_train_begin(self, logs=None):

       pass


   def on_epoch_end(self, epoch, logs=None):

       current = logs.get("loss")

       if np.less(current, self.best):

           self.best = current

           self.wait = 0

           self.best_weights = self.model.get_weights()

       else:

           self.wait += 1

           if self.wait >= self.patience:

               self.stopped_epoch = epoch

               self.model.stop_training = True

               print("Restoring model weights from the end of the best epoch.")

               self.model.set_weights(self.best_weights)


   def on_train_end(self, logs=None):

       if self.stopped_epoch > 0:

           print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))



model = get_model()

model.fit(

   x_train,

   y_train,

   batch_size=128,

   epochs=10,

   verbose=1,

   validation_split=0.5,

   callbacks=[CustomEarlyStoppingAtMinLoss()],

)

目录
相关文章
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
探索PyTorch:自动微分模块
探索PyTorch:自动微分模块
|
4月前
|
API 算法框架/工具
【Tensorflow+keras】使用keras API保存模型权重、plot画loss损失函数、保存训练loss值
使用keras API保存模型权重、plot画loss损失函数、保存训练loss值
34 0
|
6月前
|
机器学习/深度学习 存储 PyTorch
Pytorch-自动微分模块
PyTorch的torch.autograd模块提供了自动微分功能,用于深度学习中的梯度计算。它包括自定义操作的函数、构建计算图、数值梯度检查、错误检测模式和梯度模式设置等组件。张量通过设置`requires_grad=True`来追踪计算,`backward()`用于反向传播计算梯度,`grad`属性存储张量的梯度。示例展示了如何计算标量和向量张量的梯度,并通过`torch.no_grad()`等方法控制梯度计算。在优化过程中,梯度用于更新模型参数。注意,使用numpy转换要求先`detach()`以避免影响计算图。
|
7月前
|
机器学习/深度学习 TensorFlow 调度
优化TensorFlow模型:超参数调整与训练技巧
【4月更文挑战第17天】本文探讨了如何优化TensorFlow模型的性能,重点介绍了超参数调整和训练技巧。超参数如学习率、批量大小和层数对模型性能至关重要。文章提到了三种超参数调整策略:网格搜索、随机搜索和贝叶斯优化。此外,还分享了训练技巧,包括学习率调度、早停、数据增强和正则化,这些都有助于防止过拟合并提高模型泛化能力。结合这些方法,可构建更高效、健壮的深度学习模型。
|
数据格式
重写transformers.Trainer的compute_metrics方法计算评价指标时,形参如何包含自定义的数据
  这个问题苦恼我几个月,之前一直用替代方案。这次实在没替代方案了,transformers源码和文档看了一整天,终于在晚上12点找到了。。。
582 0
|
机器学习/深度学习 索引 Python
python机器学习classification_report()函数 输出模型评估报告
python机器学习classification_report()函数 输出模型评估报告
2053 0
python机器学习classification_report()函数 输出模型评估报告
|
XML 存储 TensorFlow
Tensorflow目标检测接口配合tflite量化模型(一)
Tensorflow目标检测接口配合tflite量化模型
211 0
Tensorflow目标检测接口配合tflite量化模型(一)
|
分布式计算 并行计算 Hadoop
Tensorflow目标检测接口配合tflite量化模型(二)
Tensorflow目标检测接口配合tflite量化模型
370 0
Tensorflow目标检测接口配合tflite量化模型(二)
|
机器学习/深度学习 TensorFlow API
在tensorflow2.2中使用Keras自定义模型的指标度量
在tensorflow2.2中使用Keras自定义模型的指标度量
152 0
在tensorflow2.2中使用Keras自定义模型的指标度量