由于TensorFlow已经将整个模型的训练阶段进行了封装,所以我们无法在训练期间或者预测评估期间定义自己的行为,例如打印训练进度、保存损失精度等,这是我们就可以利用回调函数
所有回调函数都将 keras.callbacks.Callback 类作为子类,并重写在训练、测试和预测的各个阶段调用的一组方法。回调函数对于在训练期间了解模型的内部状态和统计信息十分有用。
您可以将回调函数的列表(作为关键字参数 callbacks)传递给以下模型方法:
- keras.Model.fit()
- keras.Model.evaluate()
- keras.Model.predict()
回调函数方法概述全局方法
on(train|test|predict)begin(self, logs=None)
在 fit/evaluate/predict 开始时调用。
on(train|test|predict)end(self, logs=None)
在 fit/evaluate/predict 结束时调用。批次级方法(仅训练)
on(train|test|predict)batch_begin(self, batch, logs=None)
正好在训练/测试/预测期间处理批次之前调用。
on(train|test|predict)batch_end(self, batch, logs=None)
在训练/测试/预测批次结束时调用。在此方法中,logs 是包含指标结果的字典。周期级方法(仅训练)
on_epoch_begin(self, epoch, logs=None)
在训练期间周期开始时调用。
on_epoch_end(self, epoch, logs=None)
在训练期间周期开始时调用。
完整代码
"""
* Created with PyCharm
* 作者: 阿光
* 日期: 2022/1/4
* 时间: 10:02
* 描述:
"""
import tensorflow as tf
from keras import Model
from tensorflow import keras
from tensorflow.keras.layers import *
def get_model():
inputs = Input(shape=(784,))
outputs = Dense(1)(inputs)
model = Model(inputs, outputs)
model.compile(
optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
loss='mean_squared_error',
metrics=['mean_absolute_error']
)
return model
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0
x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]
class CustomCallback(keras.callbacks.Callback):
def on_train_begin(self, logs=None):
keys = list(logs.keys())
print("Starting training; got log keys: {}".format(keys))
def on_train_end(self, logs=None):
keys = list(logs.keys())
print("Stop training; got log keys: {}".format(keys))
def on_epoch_begin(self, epoch, logs=None):
keys = list(logs.keys())
print("Start epoch {} of training; got log keys: {}".format(epoch, keys))
def on_epoch_end(self, epoch, logs=None):
keys = list(logs.keys())
print("End epoch {} of training; got log keys: {}".format(epoch, keys))
def on_test_begin(self, logs=None):
keys = list(logs.keys())
print("Start testing; got log keys: {}".format(keys))
def on_test_end(self, logs=None):
keys = list(logs.keys())
print("Stop testing; got log keys: {}".format(keys))
def on_predict_begin(self, logs=None):
keys = list(logs.keys())
print("Start predicting; got log keys: {}".format(keys))
def on_predict_end(self, logs=None):
keys = list(logs.keys())
print("Stop predicting; got log keys: {}".format(keys))
def on_train_batch_begin(self, batch, logs=None):
keys = list(logs.keys())
print("...Training: start of batch {}; got log keys: {}".format(batch, keys))
def on_train_batch_end(self, batch, logs=None):
keys = list(logs.keys())
print("...Training: end of batch {}; got log keys: {}".format(batch, keys))
def on_test_batch_begin(self, batch, logs=None):
keys = list(logs.keys())
print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))
def on_test_batch_end(self, batch, logs=None):
keys = list(logs.keys())
print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))
def on_predict_batch_begin(self, batch, logs=None):
keys = list(logs.keys())
print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))
def on_predict_batch_end(self, batch, logs=None):
keys = list(logs.keys())
print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))
model = get_model()
model.fit(
x_train,
y_train,
batch_size=128,
epochs=1,
verbose=0,
validation_split=0.5,
callbacks=[CustomCallback()],
)
res = model.evaluate(
x_test,
y_test,
batch_size=128,
verbose=0,
callbacks=[CustomCallback()]
)
res = model.predict(x_test,
batch_size=128,
callbacks=[CustomCallback()])