TensorFlow中常见内置回调Callback

简介: TensorFlow中常见内置回调Callback

class BaseLogger:

计算每个epoch周期的平均指标,这个回调已经被自动应用在每个Keras模型,所以不需要手动设置

callbacks = tf.keras.callbacks.BaseLogger(
    stateful_metrics=None
)
model.fit(
    train_data,
    labels,
    epochs=5,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks
)

class CSVLogger:

将每个epoch的评估及损失结果导入到一个CSV文件中

  • filename:CSV保存路径
  • separator:不同字段之间的分割符
  • append:是否在原来的文件基础之上追加
callbacks = tf.keras.callbacks.CSVLogger(
    filename='./res.log',
    separator=',',
    append=False
)
model.fit(
    train_data,
    labels,
    epochs=5,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks
)

class EarlyStopping:

当一个被监控的指标停止提升的时候停止训练

  • monitor:需要监控的指标或者损失
  • min_delta:最小误差,只有两个epoch的评估值达到这个误差才会认为是一次变化,如果两次的误差小于min_delta则认为两次训练没有任何变化
  • patience:连续没有改进的epoch数,如果连续patience个epoch还没有改进,则停止训练
  • verbose:详细模式,用户打印控制台日志
  • mode:有三种模式,分别是minmaxauto,如果是min那么会判断如果监控的损失不在下降停止训练,如果是max,那么则发现监控的指标不在上升停止训练,如果是auto则会根据传进来的监控指标进行推断
  • baseline:监控指标的基线值,如果模型在基线上没有显示出改进,则训练将停止
  • restore_best_weights:是否从具有监控指标最佳值的epoch恢复模型权重
callbacks = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=1e-3,
    patience=2,
    verbose=0,
    mode='min',
    baseline=None,
    restore_best_weights=False
)
model.fit(
    train_data,
    labels,
    epochs=5,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks
)

class History:

将训练事件记录到history对象中,此回调会自动应用于每个 Keras 模型,history 对象由模型的 fit 方法返回。

模型训练后返回的history对象会包含训练时期每个epoch的精度或者损失值以及验证集的评估指标

class LearningRateScheduler:

学习率时间表

  • schedule:一个函数,它以epoch为索引(整数,从 0 开始索引)和当前学习率(浮点数)作为输入,并返回一个新的学习率作为输出(浮点数)。
  • verbose:是否打印学习更新情况
def scheduler(epoch, lr):
    if epoch < 10:
        return lr
    else:
        return lr * tf.math.exp(-0.1)
callbacks = tf.keras.callbacks.LearningRateScheduler(scheduler=scheduler,
                                                     verbose=1)
model.fit(
    train_data,
    labels,
    epochs=5,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks
)

class ModelCheckpoint:

以某个频率保存 Keras 模型或模型权重的回调

  • filename:保存模型或者权重的路径
  • monitor:需要监测的损失或者评估指标
  • verbose:控制台输出状态
  • save_best_only:是否保存最好的模型
  • save_weights_only:是否只保存权重,否则是保存整个模型
  • mode:监控模式,minmaxauto,是按照监控的评估指标来定,如果是损失选择min,如果是准确率这种选择max,如果是auto会根据传入的monitor自动推断
  • save_freq:两种选择,分别是epochinteger,如果是epoch是每个epoch保存一次,如果是填写一个整数,代表每训练多少个批次保存一次
  • options:其它配置,用于保存模型或者参数
callbacks = tf.keras.callbacks.ModelCheckpoint(
    filename='./save_model',
    monitor='val_loss',
    verbose=1,
    save_best_only=False,
    save_weights_only=False,
    mode='auto',
    save_freq='epoch',
    options=None
)
model.fit(
    train_data,
    labels,
    epochs=5,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks
)

class ProgbarLogger:

打印精度到标准输出


目录
相关文章
|
TensorFlow 算法框架/工具
TensorFlow自定义回调函数【全局回调、批次、epoch】
TensorFlow自定义回调函数【全局回调、批次、epoch】
195 0
|
23小时前
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
51 34
|
19天前
|
机器学习/深度学习 数据采集 数据可视化
TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤
本文介绍了 TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤,包括数据准备、模型定义、损失函数与优化器选择、模型训练与评估、模型保存与部署,并展示了构建全连接神经网络的具体示例。此外,还探讨了 TensorFlow 的高级特性,如自动微分、模型可视化和分布式训练,以及其在未来的发展前景。
46 5
|
28天前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【垃圾识别系统】实现~TensorFlow+人工智能+算法网络
垃圾识别分类系统。本系统采用Python作为主要编程语言,通过收集了5种常见的垃圾数据集('塑料', '玻璃', '纸张', '纸板', '金属'),然后基于TensorFlow搭建卷积神经网络算法模型,通过对图像数据集进行多轮迭代训练,最后得到一个识别精度较高的模型文件。然后使用Django搭建Web网页端可视化操作界面,实现用户在网页端上传一张垃圾图片识别其名称。
75 0
基于Python深度学习的【垃圾识别系统】实现~TensorFlow+人工智能+算法网络
|
28天前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
74 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
28天前
|
机器学习/深度学习 人工智能 算法
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
蔬菜识别系统,本系统使用Python作为主要编程语言,通过收集了8种常见的蔬菜图像数据集('土豆', '大白菜', '大葱', '莲藕', '菠菜', '西红柿', '韭菜', '黄瓜'),然后基于TensorFlow搭建卷积神经网络算法模型,通过多轮迭代训练最后得到一个识别精度较高的模型文件。在使用Django开发web网页端操作界面,实现用户上传一张蔬菜图片识别其名称。
77 0
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
|
21天前
|
机器学习/深度学习 人工智能 TensorFlow
基于TensorFlow的深度学习模型训练与优化实战
基于TensorFlow的深度学习模型训练与优化实战
58 0
|
1月前
|
机器学习/深度学习 人工智能 算法
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
车辆车型识别,使用Python作为主要编程语言,通过收集多种车辆车型图像数据集,然后基于TensorFlow搭建卷积网络算法模型,并对数据集进行训练,最后得到一个识别精度较高的模型文件。再基于Django搭建web网页端操作界面,实现用户上传一张车辆图片识别其类型。
80 0
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
|
3月前
|
机器学习/深度学习 人工智能 算法
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
鸟类识别系统。本系统采用Python作为主要开发语言,通过使用加利福利亚大学开源的200种鸟类图像作为数据集。使用TensorFlow搭建ResNet50卷积神经网络算法模型,然后进行模型的迭代训练,得到一个识别精度较高的模型,然后在保存为本地的H5格式文件。在使用Django开发Web网页端操作界面,实现用户上传一张鸟类图像,识别其名称。
114 12
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
|
4月前
|
机器学习/深度学习 算法 TensorFlow
深入探索强化学习与深度学习的融合:使用TensorFlow框架实现深度Q网络算法及高效调试技巧
【8月更文挑战第31天】强化学习是机器学习的重要分支,尤其在深度学习的推动下,能够解决更为复杂的问题。深度Q网络(DQN)结合了深度学习与强化学习的优势,通过神经网络逼近动作价值函数,在多种任务中表现出色。本文探讨了使用TensorFlow实现DQN算法的方法及其调试技巧。DQN通过神经网络学习不同状态下采取动作的预期回报Q(s,a),处理高维状态空间。
66 1