class BaseLogger:
计算每个epoch周期的平均指标,这个回调已经被自动应用在每个Keras模型,所以不需要手动设置
callbacks = tf.keras.callbacks.BaseLogger(
stateful_metrics=None
)
model.fit(
train_data,
labels,
epochs=5,
batch_size=32,
validation_split=0.2,
callbacks=callbacks
)
class CSVLogger:
将每个epoch的评估及损失结果导入到一个CSV文件中
filename
:CSV保存路径separator
:不同字段之间的分割符append
:是否在原来的文件基础之上追加
callbacks = tf.keras.callbacks.CSVLogger(
filename='./res.log',
separator=',',
append=False
)
model.fit(
train_data,
labels,
epochs=5,
batch_size=32,
validation_split=0.2,
callbacks=callbacks
)
class EarlyStopping:
当一个被监控的指标停止提升的时候停止训练
monitor
:需要监控的指标或者损失min_delta
:最小误差,只有两个epoch的评估值达到这个误差才会认为是一次变化,如果两次的误差小于min_delta
则认为两次训练没有任何变化patience
:连续没有改进的epoch数,如果连续patience
个epoch还没有改进,则停止训练verbose
:详细模式,用户打印控制台日志mode
:有三种模式,分别是min
,max
,auto
,如果是min那么会判断如果监控的损失不在下降停止训练,如果是max,那么则发现监控的指标不在上升停止训练,如果是auto则会根据传进来的监控指标进行推断baseline
:监控指标的基线值,如果模型在基线上没有显示出改进,则训练将停止restore_best_weights
:是否从具有监控指标最佳值的epoch恢复模型权重
callbacks = tf.keras.callbacks.EarlyStopping(
monitor='val_loss',
min_delta=1e-3,
patience=2,
verbose=0,
mode='min',
baseline=None,
restore_best_weights=False
)
model.fit(
train_data,
labels,
epochs=5,
batch_size=32,
validation_split=0.2,
callbacks=callbacks
)
class History:
将训练事件记录到history
对象中,此回调会自动应用于每个 Keras 模型,history 对象由模型的 fit 方法返回。
模型训练后返回的history对象会包含训练时期每个epoch的精度或者损失值以及验证集的评估指标
class LearningRateScheduler:
学习率时间表
schedule
:一个函数,它以epoch为索引(整数,从 0 开始索引)和当前学习率(浮点数)作为输入,并返回一个新的学习率作为输出(浮点数)。verbose
:是否打印学习更新情况
def scheduler(epoch, lr):
if epoch < 10:
return lr
else:
return lr * tf.math.exp(-0.1)
callbacks = tf.keras.callbacks.LearningRateScheduler(scheduler=scheduler,
verbose=1)
model.fit(
train_data,
labels,
epochs=5,
batch_size=32,
validation_split=0.2,
callbacks=callbacks
)
class ModelCheckpoint:
以某个频率保存 Keras 模型或模型权重的回调
filename
:保存模型或者权重的路径monitor
:需要监测的损失或者评估指标verbose
:控制台输出状态save_best_only
:是否保存最好的模型save_weights_only
:是否只保存权重,否则是保存整个模型mode
:监控模式,min
,max
,auto
,是按照监控的评估指标来定,如果是损失选择min,如果是准确率这种选择max,如果是auto会根据传入的monitor自动推断save_freq
:两种选择,分别是epoch
和integer
,如果是epoch是每个epoch保存一次,如果是填写一个整数,代表每训练多少个批次保存一次options
:其它配置,用于保存模型或者参数
callbacks = tf.keras.callbacks.ModelCheckpoint(
filename='./save_model',
monitor='val_loss',
verbose=1,
save_best_only=False,
save_weights_only=False,
mode='auto',
save_freq='epoch',
options=None
)
model.fit(
train_data,
labels,
epochs=5,
batch_size=32,
validation_split=0.2,
callbacks=callbacks
)
class ProgbarLogger:
打印精度到标准输出