sklearn中分类模型评估指标(一):准确率、Top准确率、平衡准确率

简介: accuracy_score函数计算准确率分数,即预测正确的分数(默认)或计数(当normalize=False时)。在多标签分类中,该函数返回子集准确率(subset accuracy)。 如果样本的整个预测标签集与真实标签集严格匹配,则子集准确率为 1.0; 否则为 0.0。

准确率分数


accuracy_score函数计算准确率分数,即预测正确的分数(默认)或计数(当normalize=False时)。

在多标签分类中,该函数返回子集准确率(subset accuracy)。 如果样本的整个预测标签集与真实标签集严格匹配,则子集准确率为 1.0; 否则为 0.0。

如果y^i\hat{y}_iy^i是第i个样本的预测值和yiy_iyi是对应的真实值,那么正确预测的分数,公式定义如下:

accuracy(y,y^)=1nsamples∑i=0nsamples−11(y^i=yi)=TP+TNTP+FP+TN+FN\texttt{accuracy}(y, \hat{y}) = \frac{1}{n_\text{samples}} \sum_{i=0}^{n_\text{samples}-1} 1(\hat{y}_i = y_i)=\frac{TP+TN}{TP+FP+TN+FN}accuracy(y,y^)=nsamples1i=0nsamples11(y^i=yi)=TP+FP+TN+FNTP+TN

其中,1(x)1(x)1(x) 表示指示函数(indicator function),它的含义是:当输入为True的时候,输出为1,输入为False的时候,输出为0。

关于指示函数的说明:

在数学中,指示函数是定义在某集合XXX上的函数,表示其中有哪些元素属于某一子集AAA ,常应用在集合论中。指示函数有时候也称为特征函数。

示例代码如下:

import numpy as np
from sklearn.metrics import accuracy_score
y_pred = [0, 2, 1, 3]
y_true = [0, 1, 2, 3]
print(accuracy_score(y_true, y_pred))
print(accuracy_score(y_true, y_pred, normalize=False))
复制代码


运行结果:

0.5
2
复制代码


在具有两个类标签指示符矩阵的多标签场景下,示例代码为:

print(accuracy_score(np.array([[0, 1], [1, 1]]), np.ones((2, 2))))
复制代码


运行结果:

0.5
复制代码


Top-k准确率分数


top_k_accuracy_score函数是对accuracy_score函数的扩展。 不同之处在于,只要真实标签与前 k 个最高预测分数之一相关联,就认为预测是正确的。accuracy_score是 k = 1的特例。

该函数可以应用于二分类和多分类情况,但不包括多标签情况。

如果f^i,j\hat{f}_{i,j}f^i,j是对应于第i个样本的第j个最大预测分数的预测类别,yiy_iyi是对应的真实值,那么对于nsamplesn_\text{samples}nsamples个样本,正确预测的分数被定义为

top-k accuracy(y,f^)=1nsamples∑i=0nsamples−1∑j=1k1(f^i,j=yi)\texttt{top-k accuracy}(y, \hat{f}) = \frac{1}{n_\text{samples}} \sum_{i=0}^{n_\text{samples}-1} \sum_{j=1}^{k} 1(\hat{f}_{i,j} = y_i)top-k accuracy(y,f^)=nsamples1i=0nsamples1j=1k1(f^i,j=yi)

其中,k是允许的预测个数, 1(x)1(x)1(x)是指示函数。

示例代码:

import numpy as np
from sklearn.metrics import top_k_accuracy_score
y_true = np.array([0, 1, 2, 2])
# 0, 1, 2
y_score = np.array([[0.5, 0.2, 0.2], # 0,1
                    [0.3, 0.4, 0.2], # 0,1
                    [0.2, 0.4, 0.3], # 1,2
                    [0.7, 0.2, 0.1]]) # 0,1
print(top_k_accuracy_score(y_true, y_score, k=2))
# 如果没有归一化,则返回分类样本预测正确的数量
print(top_k_accuracy_score(y_true, y_score, k=2, normalize=False))
复制代码


运行结果:

0.75
3
复制代码


平衡准确率分数


balance_accuracy_score 函数计算平衡准确率,在二分类和多分类场景中,平衡准确率用来处理不平衡数据集的问题,从而避免对不平衡数据集的评估表现夸大。它被定义为在每个类上的召回率的宏平均值,或者等效于,原始准确率(raw accuracy),其中每个样本根据其真实类别的逆流行程度(逆流行率)进行加权。 因此,对于平衡数据集,其分数等于准确率分数。

在二分类情况下,平衡准确率等于灵敏度(true positive rate,真阳性率)和特异度(true negative rate,真阴性率)的算术平均值,或者二分类情况下,预测的 ROC 曲线下面积而不是分数:

balanced-accuracy=12(TPTP+FN+TNTN+FP)=TPR+TNR2\texttt{balanced-accuracy} = \frac{1}{2}\left( \frac{TP}{TP + FN} + \frac{TN}{TN + FP}\right )=\frac{TPR+TNR}{2}balanced-accuracy=21(TP+FNTP+TN+FPTN)=2TPR+TNR

如果分类器在任一类上具有同等表现,则该术语简化为常规准确率(即正确预测的数量除以预测总数)。

相反,如果仅因为分类器使用了不平衡的测试集,导致常规的准确率高于随机值(chance=1n_classeschance=\frac{1}{n\_classes}chance=n_classes1),那么平衡准确率,这种情况下,将下降到 1n_classes\frac{1}{n\_classes}n_classes1

adjusted=False时,分数范围为从0到1,最佳值为1,最差值为0。当 adjusted=True 时,分数范围从11−n_classes\frac{1}{1 - n\_classes}1n_classes1111(包括边界),随机值分数表现是0(不平衡数据集),完美表现分数是1,完全预测错误分数为11−n_classes\frac{1}{1 - n\_classes}1n_classes1

如果 yiy_iyi是第iii个样本的真实值,并且wiw_iwi是对应的样本权重,那么我们调整样本权重为:

w^i=wi∑j1(yj=yi)wj\hat{w}_i = \frac{w_i}{\sum_j{1(y_j = y_i) w_j}}w^i=j1(yj=yi)wjwi

其中,1(x)1(x)1(x)是指示函数。给定样本 iii的预测y^i\hat{y}_iy^i,则平衡准确率公式定义为:

balanced-accuracy(y,y^,w)=1∑w^i∑i1(y^i=yi)w^i\texttt{balanced-accuracy}(y, \hat{y}, w) = \frac{1}{\sum{\hat{w}_i}} \sum_i 1(\hat{y}_i = y_i) \hat{w}_ibalanced-accuracy(y,y^,w)=w^i1i1(y^i=yi)w^i

adjusted相关源码如下:

if adjusted:
    n_classes = len(per_class) # 类别数
    chance = 1 / n_classes
    score -= chance
    score /= 1 - chance
复制代码


针对二分类情况,示例代码:

from sklearn.metrics import balanced_accuracy_score
y_true = [0, 1, 0, 0, 1, 0]
y_pred = [0, 1, 0, 0, 0, 1]
# tp=1,  fn=1,  tn=3, fp=1
# 常规:(1+3)/6 = 0.66
# 平衡:(1/2+3/4)/2 = 0.625
print(accuracy_score(y_true, y_pred))
print(balanced_accuracy_score(y_true, y_pred))
# 1/类别数
# 0.625 - 1/2  = 0.125
# 0.125 / (1-1/2) = 0.25
print(balanced_accuracy_score(y_true, y_pred, adjusted=True))
复制代码


运行结果:

0.6666666666666666
0.625
0.25
复制代码


针对多分类情况,示例代码如下:

from sklearn.metrics import accuracy_score,balanced_accuracy_score
y_true = [0, 1, 2, 0, 0]
y_pred = [0, 2, 2, 0, 1]
# 3/5
print(accuracy_score(y_true, y_pred))
# 对于0  tp=2  fn=1     2/3
# 对于1  tp=0  fn=1     0
# 对于2  tp=1  fn=0     1
# (2/3+0+1)/3 = 5/9
print(balanced_accuracy_score(y_true, y_pred, adjusted=False))
# 5/9 - 1/3 = 2/9
# (2/9)/(1-1/3) = 1/3
print(balanced_accuracy_score(y_true, y_pred, adjusted=True))
复制代码


运行结果:

0.6
0.5555555555555555
0.3333333333333332
复制代码


针对不平衡数据集的示例代码如下:

from sklearn.metrics import recall_score,balanced_accuracy_score 
def test_balanced_accuracy_score():
    y_true = [0, 1, 2, 0, 0, 1, 4]
    y_pred = [0, 2, 2, 0, 1, 1, 2]
    macro_recall = recall_score(y_true, y_pred, average='macro',
                                labels=np.unique(y_true))
    # adjusted=False时,平衡准确率
    balanced = balanced_accuracy_score(y_true, y_pred)
    print(balanced)
    # adjusted=True时,平衡准确率
    adjusted = balanced_accuracy_score(y_true, y_pred, adjusted=True)
    print(adjusted)
    print("-------------")
    # 不平衡数据集(预测的运气值)
    print(np.full_like(y_true, y_true[0]))
    print(np.full_like(y_true, y_true[1]))
    print(np.full_like(y_true, y_true[2]))
    print(np.full_like(y_true, y_true[6]))
    # 不平衡数据集(预测的运气值)
    chance = balanced_accuracy_score(y_true, np.full_like(y_true, y_true[0]))
    print(chance)
    chance = balanced_accuracy_score(y_true, np.full_like(y_true, y_true[1]))
    print(chance)
    chance = balanced_accuracy_score(y_true, np.full_like(y_true, y_true[2]))
    print(chance)
    # 从adjusted=False到adjusted=True的转换
    print(adjusted == (balanced - chance) / (1 - chance) )
    print("+++++++++++++")
    # 采用不平衡测试集(adjusted=True),则平衡准确率为0
    print(balanced_accuracy_score(y_true, np.full_like(y_true, y_true[6]), adjusted=False))
    # 采用不平衡测试集(adjusted=True),则平衡准确率为1/(n_classes)
    chance = balanced_accuracy_score(y_true, np.full_like(y_true, y_true[6]), adjusted=True)
    print(chance)
    print("-------------")
    y_true = [0, 1, 2, 0]
    y_pred = [1, 2, 0, 1]
    # 完全错误的数据集(adjusted=True),则平衡准确率为1/(1 - n_classes)
    print(balanced_accuracy_score(y_true, y_pred, adjusted=True))
    # 完全错误的数据集(adjusted=True),则平衡准确率为0
    print(balanced_accuracy_score(y_true, y_pred, adjusted=False))
test_balanced_accuracy_score()
复制代码


运行结果:

0.5416666666666666
0.38888888888888884
-------------
[0 0 0 0 0 0 0]
[1 1 1 1 1 1 1]
[2 2 2 2 2 2 2]
[4 4 4 4 4 4 4]
0.25
0.25
0.25
True
+++++++++++++
0.25
0.0
-------------
-0.49999999999999994
0.0
复制代码


总结



函数 说明
accuracy_score 适用于二分类、多分类和多标签分类场景。通常用于平衡数据集的场景。
top_k_accuracy_score 适用于二分类、多分类场景。
balance_accuracy_score 适用于二分类、多分类场景。通常用于不平衡数据集的场景。


相关文章
|
机器学习/深度学习 算法 数据挖掘
马尔科夫链(Markov Chain, MC)算法详解及Python实现
马尔科夫链(Markov Chain, MC)算法详解及Python实现
11117 115
马尔科夫链(Markov Chain, MC)算法详解及Python实现
SQL 个版本下载地址
备用:   SQL Server 2016简体中文企业版 文件名:cn_sql_server_2016_enterprise_x64_dvd_8699450.iso 64位下载地址:ed2k://|file|cn_sql_server_2016_enterprise_x64_dvd_8699450.
2538 0
|
机器学习/深度学习 算法
机器学习中最常见的四种分类模型
机器学习中最常见的四种分类模型
1278 10
|
11月前
|
搜索推荐 物联网 PyTorch
Qwen2.5-7B-Instruct Lora 微调
本教程介绍如何基于Transformers和PEFT框架对Qwen2.5-7B-Instruct模型进行LoRA微调。
11927 34
Qwen2.5-7B-Instruct Lora 微调
|
并行计算 PyTorch Linux
大概率(5重方法)解决RuntimeError: CUDA out of memory. Tried to allocate ... MiB
大概率(5重方法)解决RuntimeError: CUDA out of memory. Tried to allocate ... MiB
9803 0
|
机器学习/深度学习 人工智能 自然语言处理
深度剖析深度神经网络(DNN):原理、实现与应用
本文详细介绍了深度神经网络(DNN)的基本原理、核心算法及其具体操作步骤。DNN作为一种重要的人工智能工具,通过多层次的特征学习和权重调节,实现了复杂任务的高效解决。文章通过理论讲解与代码演示相结合的方式,帮助读者理解DNN的工作机制及实际应用。
|
机器学习/深度学习 Serverless Python
`sklearn.metrics`是scikit-learn库中用于评估机器学习模型性能的模块。它提供了多种评估指标,如准确率、精确率、召回率、F1分数、混淆矩阵等。这些指标可以帮助我们了解模型的性能,以便进行模型选择和调优。
`sklearn.metrics`是scikit-learn库中用于评估机器学习模型性能的模块。它提供了多种评估指标,如准确率、精确率、召回率、F1分数、混淆矩阵等。这些指标可以帮助我们了解模型的性能,以便进行模型选择和调优。
|
机器学习/深度学习 PyTorch 算法框架/工具
数据平衡与采样:使用 DataLoader 解决类别不平衡问题
【8月更文第29天】在机器学习项目中,类别不平衡问题非常常见,特别是在二分类或多分类任务中。当数据集中某个类别的样本远少于其他类别时,模型可能会偏向于预测样本数较多的类别,导致少数类别的预测性能较差。为了解决这个问题,可以采用不同的策略来平衡数据集,包括过采样(oversampling)、欠采样(undersampling)以及合成样本生成等方法。本文将介绍如何利用 PyTorch 的 `DataLoader` 来处理类别不平衡问题,并给出具体的代码示例。
2771 2
|
Java
springboot将list封装成csv文件
springboot将list封装成csv文件
235 4
|
机器学习/深度学习 人工智能 文字识别
ultralytics YOLO11 全新发布!(原理介绍+代码详见+结构框图)
本文详细介绍YOLO11,包括其全新特性、代码实现及结构框图,并提供如何使用NEU-DET数据集进行训练的指南。YOLO11在前代基础上引入了新功能和改进,如C3k2、C2PSA模块和更轻量级的分类检测头,显著提升了模型的性能和灵活性。文中还对比了YOLO11与YOLOv8的区别,并展示了训练过程和结果的可视化
19863 0