TensorFlow自定义回调函数【全局回调、批次、epoch】

简介: 由于TensorFlow已经将整个模型的训练阶段进行了封装,所以我们无法在训练期间或者预测评估期间定义自己的行为,例如打印训练进度、保存损失精度等,这是我们就可以利用回调函数

由于TensorFlow已经将整个模型的训练阶段进行了封装,所以我们无法在训练期间或者预测评估期间定义自己的行为,例如打印训练进度、保存损失精度等,这是我们就可以利用回调函数


所有回调函数都将 keras.callbacks.Callback 类作为子类,并重写在训练、测试和预测的各个阶段调用的一组方法。回调函数对于在训练期间了解模型的内部状态和统计信息十分有用。


您可以将回调函数的列表(作为关键字参数 callbacks)传递给以下模型方法:


  • keras.Model.fit()
  • keras.Model.evaluate()
  • keras.Model.predict()


回调函数方法概述
全局方法
on(train|test|predict)begin(self, logs=None)
在 fit/evaluate/predict 开始时调用。
on(train|test|predict)end(self, logs=None)
在 fit/evaluate/predict 结束时调用。
批次级方法(仅训练)
on(train|test|predict)batch_begin(self, batch, logs=None)
正好在训练/测试/预测期间处理批次之前调用。
on(train|test|predict)batch_end(self, batch, logs=None)
在训练/测试/预测批次结束时调用。在此方法中,logs 是包含指标结果的字典。
周期级方法(仅训练)
on_epoch_begin(self, epoch, logs=None)
在训练期间周期开始时调用。
on_epoch_end(self, epoch, logs=None)
在训练期间周期开始时调用。


完整代码


"""

* Created with PyCharm

* 作者: 阿光

* 日期: 2022/1/4

* 时间: 10:02

* 描述:

"""

import tensorflow as tf

from keras import Model

from tensorflow import keras

from tensorflow.keras.layers import *



def get_model():

   inputs = Input(shape=(784,))

   outputs = Dense(1)(inputs)

   model = Model(inputs, outputs)

   model.compile(

       optimizer=keras.optimizers.RMSprop(learning_rate=0.1),

       loss='mean_squared_error',

       metrics=['mean_absolute_error']

   )

   return model



(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train = x_train.reshape(-1, 784).astype('float32') / 255.0

x_test = x_test.reshape(-1, 784).astype('float32') / 255.0


x_train = x_train[:1000]

y_train = y_train[:1000]

x_test = x_test[:1000]

y_test = y_test[:1000]



class CustomCallback(keras.callbacks.Callback):

   def on_train_begin(self, logs=None):

       keys = list(logs.keys())

       print("Starting training; got log keys: {}".format(keys))


   def on_train_end(self, logs=None):

       keys = list(logs.keys())

       print("Stop training; got log keys: {}".format(keys))


   def on_epoch_begin(self, epoch, logs=None):

       keys = list(logs.keys())

       print("Start epoch {} of training; got log keys: {}".format(epoch, keys))


   def on_epoch_end(self, epoch, logs=None):

       keys = list(logs.keys())

       print("End epoch {} of training; got log keys: {}".format(epoch, keys))


   def on_test_begin(self, logs=None):

       keys = list(logs.keys())

       print("Start testing; got log keys: {}".format(keys))


   def on_test_end(self, logs=None):

       keys = list(logs.keys())

       print("Stop testing; got log keys: {}".format(keys))


   def on_predict_begin(self, logs=None):

       keys = list(logs.keys())

       print("Start predicting; got log keys: {}".format(keys))


   def on_predict_end(self, logs=None):

       keys = list(logs.keys())

       print("Stop predicting; got log keys: {}".format(keys))


   def on_train_batch_begin(self, batch, logs=None):

       keys = list(logs.keys())

       print("...Training: start of batch {}; got log keys: {}".format(batch, keys))


   def on_train_batch_end(self, batch, logs=None):

       keys = list(logs.keys())

       print("...Training: end of batch {}; got log keys: {}".format(batch, keys))


   def on_test_batch_begin(self, batch, logs=None):

       keys = list(logs.keys())

       print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))


   def on_test_batch_end(self, batch, logs=None):

       keys = list(logs.keys())

       print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))


   def on_predict_batch_begin(self, batch, logs=None):

       keys = list(logs.keys())

       print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))


   def on_predict_batch_end(self, batch, logs=None):

       keys = list(logs.keys())

       print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))



model = get_model()

model.fit(

   x_train,

   y_train,

   batch_size=128,

   epochs=1,

   verbose=0,

   validation_split=0.5,

   callbacks=[CustomCallback()],

)


res = model.evaluate(

   x_test,

   y_test,

   batch_size=128,

   verbose=0,

   callbacks=[CustomCallback()]

)


res = model.predict(x_test,

                   batch_size=128,

                   callbacks=[CustomCallback()])

目录
相关文章
|
7月前
|
机器学习/深度学习 存储 并行计算
ModelScope问题之训练报错设置参数如何解决
ModelScope训练是指在ModelScope平台上对机器学习模型进行训练的活动;本合集将介绍ModelScope训练流程、模型优化技巧和训练过程中的常见问题解决方法。
88 0
|
2月前
|
存储 并行计算 PyTorch
探索PyTorch:模型的定义和保存方法
探索PyTorch:模型的定义和保存方法
|
2月前
|
机器学习/深度学习 存储 数据可视化
以pytorch的forward hook为例探究hook机制
【10月更文挑战第10天】PyTorch 的 Hook 机制允许用户在不修改模型代码的情况下介入前向和反向传播过程,适用于模型可视化、特征提取及梯度分析等任务。通过注册 `forward hook`,可以在模型前向传播过程中插入自定义操作,如记录中间层输出。使用时需注意输入输出格式及计算资源占用。
|
4月前
|
运维 Serverless API
函数计算产品使用问题之如何通过API传递ControlNet参数
函数计算产品作为一种事件驱动的全托管计算服务,让用户能够专注于业务逻辑的编写,而无需关心底层服务器的管理与运维。你可以有效地利用函数计算产品来支撑各类应用场景,从简单的数据处理到复杂的业务逻辑,实现快速、高效、低成本的云上部署与运维。以下是一些关于使用函数计算产品的合集和要点,帮助你更好地理解和应用这一服务。
|
5月前
|
算法
创建一个训练函数
【7月更文挑战第22天】创建一个训练函数。
37 4
|
4月前
|
人工智能
如何让其他模型也能在SemanticKernel中调用本地函数
如何让其他模型也能在SemanticKernel中调用本地函数
46 0
|
7月前
Pyglet控件的批处理参数batch和分组参数group简析
Pyglet控件的批处理参数batch和分组参数group简析
54 0
|
数据格式
重写transformers.Trainer的compute_metrics方法计算评价指标时,形参如何包含自定义的数据
  这个问题苦恼我几个月,之前一直用替代方案。这次实在没替代方案了,transformers源码和文档看了一整天,终于在晚上12点找到了。。。
582 0
|
Dubbo Java 应用服务中间件
你该不会也觉得Dubbo参数回调中callbacks属性是用来限制回调次数的吧?
前些天,一个同事在使用Dubbo的参数回调时,骂骂咧咧的说,Dubbo的这个回调真是奇葩,居然会限制回调次数,自己不得不把callbacks属性值设置的非常大,但是还是会怕服务运行太久后超过回调次数限制,后续的回调就无法正常执行。
你该不会也觉得Dubbo参数回调中callbacks属性是用来限制回调次数的吧?
|
并行计算 PyTorch 算法框架/工具
详解PyTorch编译并调用自定义CUDA算子的三种方式
详解PyTorch编译并调用自定义CUDA算子的三种方式
971 0