TensorFlow中常见内置回调Callback

简介: 计算每个epoch周期的平均指标,这个回调已经被自动应用在每个Keras模型,所以不需要手动设置

class BaseLogger:


计算每个epoch周期的平均指标,这个回调已经被自动应用在每个Keras模型,所以不需要手动设置


callbacks = tf.keras.callbacks.BaseLogger(

   stateful_metrics=None

)


model.fit(

   train_data,

   labels,

   epochs=5,

   batch_size=32,

   validation_split=0.2,

   callbacks=callbacks

)


class CSVLogger:


将每个epoch的评估及损失结果导入到一个CSV文件中


  • filename:CSV保存路径
  • separator:不同字段之间的分割符
  • append:是否在原来的文件基础之上追加


callbacks = tf.keras.callbacks.CSVLogger(

   filename='./res.log',

   separator=',',

   append=False

)


model.fit(

   train_data,

   labels,

   epochs=5,

   batch_size=32,

   validation_split=0.2,

   callbacks=callbacks

)


class EarlyStopping:


当一个被监控的指标停止提升的时候停止训练


  • monitor:需要监控的指标或者损失
  • min_delta:最小误差,只有两个epoch的评估值达到这个误差才会认为是一次变化,如果两次的误差小于min_delta则认为两次训练没有任何变化
  • patience:连续没有改进的epoch数,如果连续patience个epoch还没有改进,则停止训练
  • verbose:详细模式,用户打印控制台日志
  • mode:有三种模式,分别是minmaxauto,如果是min那么会判断如果监控的损失不在下降停止训练,如果是max,那么则发现监控的指标不在上升停止训练,如果是auto则会根据传进来的监控指标进行推断
  • baseline:监控指标的基线值,如果模型在基线上没有显示出改进,则训练将停止
  • restore_best_weights:是否从具有监控指标最佳值的epoch恢复模型权重


callbacks = tf.keras.callbacks.EarlyStopping(

   monitor='val_loss',

   min_delta=1e-3,

   patience=2,

   verbose=0,

   mode='min',

   baseline=None,

   restore_best_weights=False

)


model.fit(

   train_data,

   labels,

   epochs=5,

   batch_size=32,

   validation_split=0.2,

   callbacks=callbacks

)


class History:


将训练事件记录到history对象中,此回调会自动应用于每个 Keras 模型,history 对象由模型的 fit 方法返回。


模型训练后返回的history对象会包含训练时期每个epoch的精度或者损失值以及验证集的评估指标


class LearningRateScheduler:


学习率时间表


  • schedule:一个函数,它以epoch为索引(整数,从 0 开始索引)和当前学习率(浮点数)作为输入,并返回一个新的学习率作为输出(浮点数)。
  • verbose:是否打印学习更新情况


def scheduler(epoch, lr):

   if epoch < 10:

       return lr

   else:

       return lr * tf.math.exp(-0.1)



callbacks = tf.keras.callbacks.LearningRateScheduler(scheduler=scheduler,

                                                    verbose=1)


model.fit(

   train_data,

   labels,

   epochs=5,

   batch_size=32,

   validation_split=0.2,

   callbacks=callbacks

)


class ModelCheckpoint:


以某个频率保存 Keras 模型或模型权重的回调


  • filename:保存模型或者权重的路径
  • monitor:需要监测的损失或者评估指标
  • verbose:控制台输出状态
  • save_best_only:是否保存最好的模型
  • save_weights_only:是否只保存权重,否则是保存整个模型
  • mode:监控模式,minmaxauto,是按照监控的评估指标来定,如果是损失选择min,如果是准确率这种选择max,如果是auto会根据传入的monitor自动推断
  • save_freq:两种选择,分别是epochinteger,如果是epoch是每个epoch保存一次,如果是填写一个整数,代表每训练多少个批次保存一次
  • options:其它配置,用于保存模型或者参数


callbacks = tf.keras.callbacks.ModelCheckpoint(

   filename='./save_model',

   monitor='val_loss',

   verbose=1,

   save_best_only=False,

   save_weights_only=False,

   mode='auto',

   save_freq='epoch',

   options=None

)


model.fit(

   train_data,

   labels,

   epochs=5,

   batch_size=32,

   validation_split=0.2,

   callbacks=callbacks

)


class ProgbarLogger:


打印精度到标准输出

目录
相关文章
|
2月前
|
机器学习/深度学习 IDE API
【Tensorflow+keras】Keras 用Class类封装的模型如何调试call子函数的模型内部变量
该文章介绍了一种调试Keras中自定义Layer类的call方法的方法,通过直接调用call方法并传递输入参数来进行调试。
22 4
|
3月前
|
机器学习/深度学习 人工智能 API
LangChain之模型调用
LangChain的模型是框架中的核心,基于语言模型构建,用于开发LangChain应用。通过API调用大模型来解决问题是LangChain应用开发的关键过程。
78 1
|
5月前
|
前端开发 Python
探索Python中的异步编程:从回调到async/await
本文将深入探讨Python中的异步编程模式,从最初的回调函数到现代的async/await语法。我们将介绍异步编程的基本概念,探讨其在Python中的实现方式,以及如何使用asyncio库和async/await语法来简化异步代码的编写。通过本文,读者将能够全面了解Python中的异步编程,并掌握使用异步技术构建高效、响应式应用程序的方法。
|
TensorFlow API 算法框架/工具
TensorFlow利用函数API实现简易自编码器
TensorFlow利用函数API实现简易自编码器
63 0
TensorFlow利用函数API实现简易自编码器
|
TensorFlow 算法框架/工具
Tensorflow 出现 ‘Tensor‘ object is not callable解决办法
Tensorflow 出现 ‘Tensor‘ object is not callable解决办法
279 0
Tensorflow 出现 ‘Tensor‘ object is not callable解决办法
|
监控 TensorFlow 算法框架/工具
TensorFlow中常见内置回调Callback
TensorFlow中常见内置回调Callback
130 0
|
TensorFlow 算法框架/工具
TensorFlow自定义回调函数【全局回调、批次、epoch】
TensorFlow自定义回调函数【全局回调、批次、epoch】
177 0
es6 generator 生成器学习总结 使用生成器实现异步请求, async await 的前身
es6 generator 生成器学习总结 使用生成器实现异步请求, async await 的前身
|
TensorFlow 算法框架/工具
tensorflow报错:AttributeError: module ‘tensorflow._api.v2.compat.v1‘ has no attribute ‘Sessions‘,亲测有效
tensorflow报错:AttributeError: module ‘tensorflow._api.v2.compat.v1‘ has no attribute ‘Sessions‘,
1476 0
tensorflow报错:AttributeError: module ‘tensorflow._api.v2.compat.v1‘ has no attribute ‘Sessions‘,亲测有效