TensorFlow自定义回调函数【全局回调、批次、epoch】

简介: TensorFlow自定义回调函数【全局回调、批次、epoch】

由于TensorFlow已经将整个模型的训练阶段进行了封装,所以我们无法在训练期间或者预测评估期间定义自己的行为,例如打印训练进度、保存损失精度等,这是我们就可以利用回调函数

所有回调函数都将 keras.callbacks.Callback 类作为子类,并重写在训练、测试和预测的各个阶段调用的一组方法。回调函数对于在训练期间了解模型的内部状态和统计信息十分有用。

您可以将回调函数的列表(作为关键字参数 callbacks)传递给以下模型方法:

  • keras.Model.fit()
  • keras.Model.evaluate()
  • keras.Model.predict()

回调函数方法概述

全局方法

on_(train|test|predict)begin(self, logs=None)
在 fit/evaluate/predict 开始时调用。
on
(train|test|predict)end(self, logs=None)
在 fit/evaluate/predict 结束时调用。
批次级方法(仅训练)
on
(train|test|predict)batch_begin(self, batch, logs=None)
正好在训练/测试/预测期间处理批次之前调用。
on
(train|test|predict)_batch_end(self, batch, logs=None)

在训练/测试/预测批次结束时调用。在此方法中,logs 是包含指标结果的字典。

周期级方法(仅训练)

on_epoch_begin(self, epoch, logs=None)

在训练期间周期开始时调用。

on_epoch_end(self, epoch, logs=None)

在训练期间周期开始时调用。

完整代码

"""
 * Created with PyCharm
 * 作者: 阿光
 * 日期: 2022/1/4
 * 时间: 10:02
 * 描述:
"""
import tensorflow as tf
from keras import Model
from tensorflow import keras
from tensorflow.keras.layers import *
def get_model():
    inputs = Input(shape=(784,))
    outputs = Dense(1)(inputs)
    model = Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss='mean_squared_error',
        metrics=['mean_absolute_error']
    )
    return model
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0
x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]
class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))
    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))
    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))
    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))
    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))
    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))
    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))
    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))
    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))
    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))
    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))
    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))
    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))
    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))
model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=1,
    verbose=0,
    validation_split=0.5,
    callbacks=[CustomCallback()],
)
res = model.evaluate(
    x_test,
    y_test,
    batch_size=128,
    verbose=0,
    callbacks=[CustomCallback()]
)
res = model.predict(x_test,
                    batch_size=128,
                    callbacks=[CustomCallback()])


目录
相关文章
|
5月前
|
TensorFlow 算法框架/工具
【Tensorflow+Keras】学习率指数、分段、逆时间、多项式衰减及自定义学习率衰减的完整实例
使用Tensorflow和Keras实现学习率衰减的完整实例,包括指数衰减、分段常数衰减、多项式衰减、逆时间衰减以及如何通过callbacks自定义学习率衰减策略。
86 0
|
8月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
TensorFlow中的自定义层与模型
【4月更文挑战第17天】本文介绍了如何在TensorFlow中创建自定义层和模型。自定义层通过继承`tf.keras.layers.Layer`,实现`__init__`, `build`和`call`方法。例如,一个简单的全连接层`CustomDenseLayer`示例展示了如何定义激活函数。自定义模型则继承自`tf.keras.Model`,在`__init__`中定义层,在`call`中实现前向传播。这两个功能使TensorFlow能应对特定需求和复杂网络结构,增强了其在深度学习应用中的灵活性。
|
机器学习/深度学习 数据可视化 Java
TensorFlow 高级技巧:自定义模型保存、加载和分布式训练
本篇文章将涵盖 TensorFlow 的高级应用,包括如何自定义模型的保存和加载过程,以及如何进行分布式训练。
|
机器学习/深度学习 TensorFlow API
构建自定义机器学习模型:TensorFlow的高级用法
在机器学习领域,TensorFlow已经成为最受欢迎和广泛使用的开源框架之一。它提供了丰富的功能和灵活性,使开发者能够构建各种复杂的机器学习模型。在本文中,我们将深入探讨TensorFlow的高级用法,重点介绍如何构建自定义机器学习模型。
233 0
|
监控 TensorFlow 算法框架/工具
TensorFlow中常见内置回调Callback
TensorFlow中常见内置回调Callback
150 0
|
TensorFlow 算法框架/工具
TensorFlow指定每个epoch验证多少个批次数据集
TensorFlow指定每个epoch验证多少个批次数据集
142 0
|
TensorFlow 算法框架/工具
TensorFlow指定每个epoch训练多少个批次的数据
TensorFlow指定每个epoch训练多少个批次的数据
111 0
|
TensorFlow 算法框架/工具
TensorFlow自定义评估指标
TensorFlow自定义评估指标
229 0
|
19天前
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
208 55
|
2月前
|
机器学习/深度学习 数据采集 数据可视化
TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤
本文介绍了 TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤,包括数据准备、模型定义、损失函数与优化器选择、模型训练与评估、模型保存与部署,并展示了构建全连接神经网络的具体示例。此外,还探讨了 TensorFlow 的高级特性,如自动微分、模型可视化和分布式训练,以及其在未来的发展前景。
110 5