Because TensorFlow encapsulates the whole training loop of a model, we cannot define our own behavior during training, evaluation, or prediction, such as printing training progress or saving loss and accuracy values. This is where callbacks come in.
All callbacks subclass keras.callbacks.Callback and override a set of methods that are called at various stages of training, testing, and predicting. Callbacks are useful for getting a view of the model's internal state and statistics during training.
You can pass a list of callbacks (as the keyword argument callbacks) to the following model methods:
- keras.Model.fit()
- keras.Model.evaluate()
- keras.Model.predict()
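For example, here is a minimal sketch of passing the same callback list to all three methods; the model and data are made up purely for illustration:

```python
import numpy as np
from tensorflow import keras

# A tiny throwaway model (names are illustrative only).
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="rmsprop", loss="mse")

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

# The same `callbacks` keyword works for fit, evaluate, and predict;
# any keras.callbacks.Callback subclass (custom or built-in) can go in the list.
cbs = [keras.callbacks.History()]
history = model.fit(x, y, epochs=1, verbose=0, callbacks=cbs)
model.evaluate(x, y, verbose=0, callbacks=cbs)
model.predict(x, verbose=0, callbacks=cbs)
```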
Overview of callback methods
Global methods
on_(train|test|predict)_begin(self, logs=None)
Called at the beginning of fit/evaluate/predict.
on_(train|test|predict)_end(self, logs=None)
Called at the end of fit/evaluate/predict.
Batch-level methods (training/testing/predicting)
on_(train|test|predict)_batch_begin(self, batch, logs=None)
Called right before processing a batch during training/testing/predicting.
on_(train|test|predict)_batch_end(self, batch, logs=None)
Called at the end of training/testing/predicting a batch. Within this method, logs is a dict containing the metric results.
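As a concrete sketch of a batch-level hook (class and variable names are made up for illustration), the callback below reads the running loss out of logs at the end of every training batch:

```python
import numpy as np
from tensorflow import keras

class BatchLossLogger(keras.callbacks.Callback):
    """Collects the loss reported in `logs` after each training batch."""

    def __init__(self):
        super().__init__()
        self.batch_losses = []

    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        if "loss" in logs:
            self.batch_losses.append(float(logs["loss"]))

model = keras.Sequential([keras.Input(shape=(2,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

x = np.random.rand(16, 2).astype("float32")
y = np.random.rand(16, 1).astype("float32")

logger = BatchLossLogger()
# 16 samples / batch_size 4 -> the hook fires once per batch, 4 times total.
model.fit(x, y, batch_size=4, epochs=1, verbose=0, callbacks=[logger])
```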
Epoch-level methods (training only)
on_epoch_begin(self, epoch, logs=None)
Called at the beginning of an epoch during training.
on_epoch_end(self, epoch, logs=None)
Called at the end of an epoch during training.
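A typical use of an epoch-level hook is early stopping. The sketch below (class name and threshold are made-up examples) checks the epoch's loss in on_epoch_end and halts training via self.model.stop_training:

```python
import numpy as np
from tensorflow import keras

class StopAtLoss(keras.callbacks.Callback):
    """Stops training once the epoch loss falls below a threshold."""

    def __init__(self, threshold):
        super().__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and loss < self.threshold:
            # Setting this flag asks fit() to stop after the current epoch.
            self.model.stop_training = True

model = keras.Sequential([keras.Input(shape=(1,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

# All-zero data keeps the loss trivially small, so training stops early
# instead of running the full 50 epochs.
x = np.zeros((8, 1), dtype="float32")
y = np.zeros((8, 1), dtype="float32")
hist = model.fit(x, y, epochs=50, verbose=0, callbacks=[StopAtLoss(1e-3)])
```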
Complete code
""" * Created with PyCharm * 作者: 阿光 * 日期: 2022/1/4 * 时间: 10:02 * 描述: """ import tensorflow as tf from keras import Model from tensorflow import keras from tensorflow.keras.layers import * def get_model(): inputs = Input(shape=(784,)) outputs = Dense(1)(inputs) model = Model(inputs, outputs) model.compile( optimizer=keras.optimizers.RMSprop(learning_rate=0.1), loss='mean_squared_error', metrics=['mean_absolute_error'] ) return model (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() x_train = x_train.reshape(-1, 784).astype('float32') / 255.0 x_test = x_test.reshape(-1, 784).astype('float32') / 255.0 x_train = x_train[:1000] y_train = y_train[:1000] x_test = x_test[:1000] y_test = y_test[:1000] class CustomCallback(keras.callbacks.Callback): def on_train_begin(self, logs=None): keys = list(logs.keys()) print("Starting training; got log keys: {}".format(keys)) def on_train_end(self, logs=None): keys = list(logs.keys()) print("Stop training; got log keys: {}".format(keys)) def on_epoch_begin(self, epoch, logs=None): keys = list(logs.keys()) print("Start epoch {} of training; got log keys: {}".format(epoch, keys)) def on_epoch_end(self, epoch, logs=None): keys = list(logs.keys()) print("End epoch {} of training; got log keys: {}".format(epoch, keys)) def on_test_begin(self, logs=None): keys = list(logs.keys()) print("Start testing; got log keys: {}".format(keys)) def on_test_end(self, logs=None): keys = list(logs.keys()) print("Stop testing; got log keys: {}".format(keys)) def on_predict_begin(self, logs=None): keys = list(logs.keys()) print("Start predicting; got log keys: {}".format(keys)) def on_predict_end(self, logs=None): keys = list(logs.keys()) print("Stop predicting; got log keys: {}".format(keys)) def on_train_batch_begin(self, batch, logs=None): keys = list(logs.keys()) print("...Training: start of batch {}; got log keys: {}".format(batch, keys)) def on_train_batch_end(self, batch, logs=None): keys = list(logs.keys()) 
print("...Training: end of batch {}; got log keys: {}".format(batch, keys)) def on_test_batch_begin(self, batch, logs=None): keys = list(logs.keys()) print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys)) def on_test_batch_end(self, batch, logs=None): keys = list(logs.keys()) print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys)) def on_predict_batch_begin(self, batch, logs=None): keys = list(logs.keys()) print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys)) def on_predict_batch_end(self, batch, logs=None): keys = list(logs.keys()) print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys)) model = get_model() model.fit( x_train, y_train, batch_size=128, epochs=1, verbose=0, validation_split=0.5, callbacks=[CustomCallback()], ) res = model.evaluate( x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()] ) res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])