TensorFlow自定义评估指标

简介: 有些时候我们的指标不止这些,需要根据我们自己特定的任务指定自己的评估指标,这时就需要自定义Metric,需要子类化Metric,也就是继承keras.metrics.Metric,然后实现它的方法

TensorFlow内置常用指标:


  • AUC()
  • Precision()
  • Recall()
  • 等等


有些时候我们的指标不止这些,需要根据我们自己特定的任务指定自己的评估指标,这时就需要自定义Metric,需要子类化Metric,也就是继承keras.metrics.Metric,然后实现它的方法:


  • __init__:这个方法是用来初始化一些变量的
  • update_state:参数有真实值、预测值,采样权重,我们需要在这个方法内进行更新状态变量
  • result:使用状态变量计算最终的评估结果
  • reset_states:重新初始化状态变量


下方实现的评估指标是计算有多少个正类被评估正确,就是预测对了多少样本


完整代码:


"""

* Created with PyCharm

* 作者: 阿光

* 日期: 2022/1/2

* 时间: 18:32

* 描述:

"""

import tensorflow as tf

import tensorflow.keras.datasets.mnist

from keras import Input, Model

from keras.layers import Dense

from tensorflow import keras


(train_images, train_labels), (val_images, val_labels) = tensorflow.keras.datasets.mnist.load_data()


train_images, val_images = train_images / 255.0, val_images / 255.0


train_images = train_images.reshape(60000, 784)

val_images = val_images.reshape(10000, 784)



class CategoricalTruePositives(keras.metrics.Metric):

   def __init__(self, name="categorical_true_positives"):

       super(CategoricalTruePositives, self).__init__(name=name)

       self.true_positives = self.add_weight(name='ctp', initializer='zeros')


   def update_state(self, y_true, y_pred, sample_weight=None):

       y_pred = tf.reshape(tf.argmax(y_pred, axis=1), shape=(-1, 1))

       values = tf.cast(y_true, 'int32') == tf.cast(y_pred, 'int32')

       values = tf.cast(values, 'float32')

       self.true_positives.assign_add(tf.reduce_sum(values))


   def result(self):

       return self.true_positives


   def reset_states(self):

       self.true_positives.assign(0.0)



def get_model():

   inputs = Input(shape=(784,))

   outputs = Dense(10, activation='softmax')(inputs)

   model = Model(inputs, outputs)

   model.compile(

       optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),

       loss=keras.losses.SparseCategoricalCrossentropy(),

       metrics=[CategoricalTruePositives()]

   )

   return model



model = get_model()


model.fit(

   train_images,

   train_labels,

   epochs=5,

   batch_size=32,

   validation_data=(val_images, val_labels)

)

目录
相关文章
|
5月前
|
机器学习/深度学习 搜索推荐 算法
推荐系统离线评估方法和评估指标,以及在推荐服务器内部实现A/B测试和解决A/B测试资源紧张的方法。还介绍了如何在TensorFlow中进行模型离线评估实践。
推荐系统离线评估方法和评估指标,以及在推荐服务器内部实现A/B测试和解决A/B测试资源紧张的方法。还介绍了如何在TensorFlow中进行模型离线评估实践。
383 0
|
存储 人工智能 Prometheus
ML 模型监控最佳工具(上)
如果您迟早将模型部署到生产环境,那么您将开始寻找 ML 模型监控工具。 当您的 ML 模型影响业务时,您只需要了解“事物是如何工作的”。 当事物停止工作时,你真正感受到这一点的第一刻。如果没有设置模型监控,您可能不知道哪里出了问题以及从哪里开始寻找问题和解决方案。
|
3月前
|
机器学习/深度学习 数据采集 PyTorch
PyTorch模型训练与部署流程详解
【7月更文挑战第14天】PyTorch以其灵活性和易用性在模型训练与部署中展现出强大的优势。通过遵循上述流程,我们可以有效地完成模型的构建、训练和部署工作,并将深度学习技术应用于各种实际场景中。随着技术的不断进步和应用的深入,我们相信PyTorch将在未来的机器学习和深度学习领域发挥更加重要的作用。
|
5月前
|
机器学习/深度学习 监控 测试技术
TensorFlow的模型评估与验证
【4月更文挑战第17天】TensorFlow是深度学习中用于模型评估与验证的重要框架,提供多样工具支持这一过程。模型评估衡量模型在未知数据上的表现,帮助识别性能和优化方向。在TensorFlow中,使用验证集和测试集评估模型,选择如准确率、召回率等指标,并通过`tf.keras.metrics`模块更新和获取评估结果。模型验证则确保模型稳定性和泛化能力,常用方法包括交叉验证和留出验证。通过这些方法,开发者能有效提升模型质量和性能。
|
5月前
|
机器学习/深度学习 TensorFlow 调度
优化TensorFlow模型:超参数调整与训练技巧
【4月更文挑战第17天】本文探讨了如何优化TensorFlow模型的性能,重点介绍了超参数调整和训练技巧。超参数如学习率、批量大小和层数对模型性能至关重要。文章提到了三种超参数调整策略:网格搜索、随机搜索和贝叶斯优化。此外,还分享了训练技巧,包括学习率调度、早停、数据增强和正则化,这些都有助于防止过拟合并提高模型泛化能力。结合这些方法,可构建更高效、健壮的深度学习模型。
|
机器学习/深度学习 算法 PyTorch
机器学习之PyTorch和Scikit-Learn第6章 学习模型评估和超参数调优的最佳实践Part 2
本节中,我们来看两个非常简单但强大的诊断工具,可帮助我们提升学习算法的性能:学习曲线和验证曲线,在接下的小节中,我们会讨论如何使用学习曲线诊断学习算法是否有过拟合(高方差)或欠拟合(高偏置)的问题。另外,我们还会学习验证曲线,它辅助我们处理学习算法中的常见问题。
326 0
机器学习之PyTorch和Scikit-Learn第6章 学习模型评估和超参数调优的最佳实践Part 2
|
机器学习/深度学习 存储 数据采集
机器学习之PyTorch和Scikit-Learn第6章 学习模型评估和超参数调优的最佳实践Part 1
在前面的章节中,我们学习了用于分类的基本机器学习算法以及如何在喂给这些算法前处理好数据。下面该学习通过调优算法和评估模型表现来构建良好机器学习模型的最佳实践了。本章中,我们将学习如下内容: 评估机器学习模型表现 诊断机器学习算法常见问题 调优机器学习模型 使用不同的性能指标评估预测模型 通过管道流程化工作流
281 0
机器学习之PyTorch和Scikit-Learn第6章 学习模型评估和超参数调优的最佳实践Part 1
|
机器学习/深度学习 数据可视化 JavaScript
Tensorflow的训练流程和部署流程你知道吗?
Tensorflow的训练流程和部署流程你知道吗?
141 0
|
机器学习/深度学习 人工智能 算法
机器学习之PyTorch和Scikit-Learn第6章 学习模型评估和超参数调优的最佳实践Part 3
在前面的章节中,我们使用预测准确率来评估各机器学习模型,通常这是用于量化模型表现很有用的指标。但还有其他几个性能指标可以用于衡量模型的相关性,例如精确率、召回率、F1分数和马修斯相关系数(MCC)等。
333 0
|
XML 存储 TensorFlow
Tensorflow目标检测接口配合tflite量化模型(一)
Tensorflow目标检测接口配合tflite量化模型
195 0
Tensorflow目标检测接口配合tflite量化模型(一)
下一篇
无影云桌面