TensorFlow中常见内置回调Callback

简介: TensorFlow中常见内置回调Callback

class BaseLogger:

计算每个epoch周期的平均指标,这个回调已经被自动应用在每个Keras模型,所以不需要手动设置

callbacks = tf.keras.callbacks.BaseLogger(
    stateful_metrics=None
)
model.fit(
    train_data,
    labels,
    epochs=5,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks
)

class CSVLogger:

将每个epoch的评估及损失结果导入到一个CSV文件中

  • filename:CSV保存路径
  • separator:不同字段之间的分割符
  • append:是否在原来的文件基础之上追加
callbacks = tf.keras.callbacks.CSVLogger(
    filename='./res.log',
    separator=',',
    append=False
)
model.fit(
    train_data,
    labels,
    epochs=5,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks
)

class EarlyStopping:

当一个被监控的指标停止提升的时候停止训练

  • monitor:需要监控的指标或者损失
  • min_delta:最小误差,只有两个epoch的评估值达到这个误差才会认为是一次变化,如果两次的误差小于min_delta则认为两次训练没有任何变化
  • patience:连续没有改进的epoch数,如果连续patience个epoch还没有改进,则停止训练
  • verbose:详细模式,用户打印控制台日志
  • mode:有三种模式,分别是minmaxauto,如果是min那么会判断如果监控的损失不在下降停止训练,如果是max,那么则发现监控的指标不在上升停止训练,如果是auto则会根据传进来的监控指标进行推断
  • baseline:监控指标的基线值,如果模型在基线上没有显示出改进,则训练将停止
  • restore_best_weights:是否从具有监控指标最佳值的epoch恢复模型权重
callbacks = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=1e-3,
    patience=2,
    verbose=0,
    mode='min',
    baseline=None,
    restore_best_weights=False
)
model.fit(
    train_data,
    labels,
    epochs=5,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks
)

class History:

将训练事件记录到history对象中,此回调会自动应用于每个 Keras 模型,history 对象由模型的 fit 方法返回。

模型训练后返回的history对象会包含训练时期每个epoch的精度或者损失值以及验证集的评估指标

class LearningRateScheduler:

学习率时间表

  • schedule:一个函数,它以epoch为索引(整数,从 0 开始索引)和当前学习率(浮点数)作为输入,并返回一个新的学习率作为输出(浮点数)。
  • verbose:是否打印学习更新情况
def scheduler(epoch, lr):
    if epoch < 10:
        return lr
    else:
        return lr * tf.math.exp(-0.1)
callbacks = tf.keras.callbacks.LearningRateScheduler(scheduler=scheduler,
                                                     verbose=1)
model.fit(
    train_data,
    labels,
    epochs=5,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks
)

class ModelCheckpoint:

以某个频率保存 Keras 模型或模型权重的回调

  • filename:保存模型或者权重的路径
  • monitor:需要监测的损失或者评估指标
  • verbose:控制台输出状态
  • save_best_only:是否保存最好的模型
  • save_weights_only:是否只保存权重,否则是保存整个模型
  • mode:监控模式,minmaxauto,是按照监控的评估指标来定,如果是损失选择min,如果是准确率这种选择max,如果是auto会根据传入的monitor自动推断
  • save_freq:两种选择,分别是epochinteger,如果是epoch是每个epoch保存一次,如果是填写一个整数,代表每训练多少个批次保存一次
  • options:其它配置,用于保存模型或者参数
callbacks = tf.keras.callbacks.ModelCheckpoint(
    filename='./save_model',
    monitor='val_loss',
    verbose=1,
    save_best_only=False,
    save_weights_only=False,
    mode='auto',
    save_freq='epoch',
    options=None
)
model.fit(
    train_data,
    labels,
    epochs=5,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks
)

class ProgbarLogger:

打印精度到标准输出


目录
相关文章
|
TensorFlow 算法框架/工具
TensorFlow自定义回调函数【全局回调、批次、epoch】
TensorFlow自定义回调函数【全局回调、批次、epoch】
180 0
|
4月前
|
机器学习/深度学习 TensorFlow API
TensorFlow与Keras实战:构建深度学习模型
本文探讨了TensorFlow和其高级API Keras在深度学习中的应用。TensorFlow是Google开发的高性能开源框架,支持分布式计算,而Keras以其用户友好和模块化设计简化了神经网络构建。通过一个手写数字识别的实战案例,展示了如何使用Keras加载MNIST数据集、构建CNN模型、训练及评估模型,并进行预测。案例详述了数据预处理、模型构建、训练过程和预测新图像的步骤,为读者提供TensorFlow和Keras的基础实践指导。
341 59
|
4月前
|
机器学习/深度学习 人工智能 算法
海洋生物识别系统+图像识别+Python+人工智能课设+深度学习+卷积神经网络算法+TensorFlow
海洋生物识别系统。以Python作为主要编程语言,通过TensorFlow搭建ResNet50卷积神经网络算法,通过对22种常见的海洋生物('蛤蜊', '珊瑚', '螃蟹', '海豚', '鳗鱼', '水母', '龙虾', '海蛞蝓', '章鱼', '水獭', '企鹅', '河豚', '魔鬼鱼', '海胆', '海马', '海豹', '鲨鱼', '虾', '鱿鱼', '海星', '海龟', '鲸鱼')数据集进行训练,得到一个识别精度较高的模型文件,然后使用Django开发一个Web网页平台操作界面,实现用户上传一张海洋生物图片识别其名称。
170 7
海洋生物识别系统+图像识别+Python+人工智能课设+深度学习+卷积神经网络算法+TensorFlow
|
4月前
|
机器学习/深度学习 人工智能 算法
【乐器识别系统】图像识别+人工智能+深度学习+Python+TensorFlow+卷积神经网络+模型训练
乐器识别系统。使用Python为主要编程语言,基于人工智能框架库TensorFlow搭建ResNet50卷积神经网络算法,通过对30种乐器('迪吉里杜管', '铃鼓', '木琴', '手风琴', '阿尔卑斯号角', '风笛', '班卓琴', '邦戈鼓', '卡萨巴', '响板', '单簧管', '古钢琴', '手风琴(六角形)', '鼓', '扬琴', '长笛', '刮瓜', '吉他', '口琴', '竖琴', '沙槌', '陶笛', '钢琴', '萨克斯管', '锡塔尔琴', '钢鼓', '长号', '小号', '大号', '小提琴')的图像数据集进行训练,得到一个训练精度较高的模型,并将其
65 0
【乐器识别系统】图像识别+人工智能+深度学习+Python+TensorFlow+卷积神经网络+模型训练
|
4月前
|
机器学习/深度学习 人工智能 TensorFlow
TensorFlow 是一个由 Google 开发的开源深度学习框架
TensorFlow 是一个由 Google 开发的开源深度学习框架
67 3
|
1月前
|
机器学习/深度学习 数据挖掘 TensorFlow
解锁Python数据分析新技能,TensorFlow&PyTorch双引擎驱动深度学习实战盛宴
在数据驱动时代,Python凭借简洁的语法和强大的库支持,成为数据分析与机器学习的首选语言。Pandas和NumPy是Python数据分析的基础,前者提供高效的数据处理工具,后者则支持科学计算。TensorFlow与PyTorch作为深度学习领域的两大框架,助力数据科学家构建复杂神经网络,挖掘数据深层价值。通过Python打下的坚实基础,结合TensorFlow和PyTorch的强大功能,我们能在数据科学领域探索无限可能,解决复杂问题并推动科研进步。
51 0
|
1月前
|
机器学习/深度学习 数据挖掘 TensorFlow
从数据小白到AI专家:Python数据分析与TensorFlow/PyTorch深度学习的蜕变之路
【9月更文挑战第10天】从数据新手成长为AI专家,需先掌握Python基础语法,并学会使用NumPy和Pandas进行数据分析。接着,通过Matplotlib和Seaborn实现数据可视化,最后利用TensorFlow或PyTorch探索深度学习。这一过程涉及从数据清洗、可视化到构建神经网络的多个步骤,每一步都需不断实践与学习。借助Python的强大功能及各类库的支持,你能逐步解锁数据的深层价值。
54 0
|
2月前
|
持续交付 测试技术 jenkins
JSF 邂逅持续集成,紧跟技术热点潮流,开启高效开发之旅,引发开发者强烈情感共鸣
【8月更文挑战第31天】在快速发展的软件开发领域,JavaServer Faces(JSF)这一强大的Java Web应用框架与持续集成(CI)结合,可显著提升开发效率及软件质量。持续集成通过频繁的代码集成及自动化构建测试,实现快速反馈、高质量代码、加强团队协作及简化部署流程。以Jenkins为例,配合Maven或Gradle,可轻松搭建JSF项目的CI环境,通过JUnit和Selenium编写自动化测试,确保每次构建的稳定性和正确性。
53 0
|
2月前
|
测试技术 数据库
探索JSF单元测试秘籍!如何让您的应用更稳固、更高效?揭秘成功背后的测试之道!
【8月更文挑战第31天】在 JavaServer Faces(JSF)应用开发中,确保代码质量和可维护性至关重要。本文详细介绍了如何通过单元测试实现这一目标。首先,阐述了单元测试的重要性及其对应用稳定性的影响;其次,提出了提高 JSF 应用可测试性的设计建议,如避免直接访问外部资源和使用依赖注入;最后,通过一个具体的 `UserBean` 示例,展示了如何利用 JUnit 和 Mockito 框架编写有效的单元测试。通过这些方法,不仅能够确保代码质量,还能提高开发效率和降低维护成本。
47 0
|
2月前
|
UED 开发者
哇塞!Uno Platform 数据绑定超全技巧大揭秘!从基础绑定到高级转换,优化性能让你的开发如虎添翼
【8月更文挑战第31天】在开发过程中,数据绑定是连接数据模型与用户界面的关键环节,可实现数据自动更新。Uno Platform 提供了简洁高效的数据绑定方式,使属性变化时 UI 自动同步更新。通过示例展示了基本绑定方法及使用 `Converter` 转换数据的高级技巧,如将年龄转换为格式化字符串。此外,还可利用 `BindingMode.OneTime` 提升性能。掌握这些技巧能显著提高开发效率并优化用户体验。
53 0

热门文章

最新文章