多分类混淆矩阵详解

简介: 多分类混淆矩阵详解

⭐️ 前言

机器学习和数据科学中,混淆矩阵(Confusion Matrix)是一个重要的工具,用于评估分类模型的性能。特别是在多分类问题中,混淆矩阵能够清晰地展示模型在每个类别上的预测结果。以下是对多分类混淆矩阵的详细解释。

⭐️ 1. 混淆矩阵的基本概念

混淆矩阵是一个N x N的矩阵(N代表类别数量),它的每一行代表一个实际类别,每一列代表一个预测类别。矩阵中的每个元素C[i][j]表示实际为第i类但被预测为第j类的样本数量。

在多分类问题中,混淆矩阵的结构如下:

真实值\预测值 Predicted: 0 Predicted: 1 Predicted: N-1
Actual: 0 C[0][0] C[0][1] C[0][N-1]
Actual: 1 C[1][0] C[1][1] C[1][N-1]
Actual: N-1 C[N-1][0] C[N-1][1] C[N-1][N-1]

⭐️ 2. 混淆矩阵中的重要指标

真正例(True Positives, TP):实际为正例且预测为正例的样本数量,对应矩阵的对角线元素C[i][i]。

假正例(False Positives, FP):实际为负例但预测为正例的样本数量,对应矩阵非对角线上的元素C[i][j](i ≠ j)。

真负例(True Negatives, TN):在多分类问题中通常不直接计算,但在二分类问题中用于表示实际为负例且预测为负例的样本数量。

假负例(False Negatives, FN):实际为正例但预测为负例的样本数量,在多分类问题中,这通常表示被错误分类到其他类别的样本。

⭐️ 3. 从混淆矩阵计算评估指标

⭐️ 4. 使用Python计算混淆矩阵和评估指标

在Python中,我们可以使用sklearn.metrics模块来计算混淆矩阵和评估指标。以下是一个简单的示例:

from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score  
from sklearn.model_selection import train_test_split  
from sklearn.datasets import load_iris  
from sklearn.svm import SVC  
  
# 加载数据集  
iris = load_iris()  
X = iris.data  
y = iris.target  
  
# 划分数据集为训练集和测试集  
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)  
  
# 训练模型(这里使用SVM作为示例)  
clf = SVC(kernel='linear', C=1, random_state=42)  
clf.fit(X_train, y_train)  
  
# 对测试集进行预测  
y_pred = clf.predict(X_test)  
  
 # 计算混淆矩阵  
cm = confusion_matrix(y_test, y_pred)  
print("Confusion Matrix:")  
print(cm)  
  
# 计算准确率  
accuracy = accuracy_score(y_test, y_pred)  
print("Accuracy:", accuracy)  
  
# 计算每个类别的精准率、召回率和F1分数  
precision = precision_score(y_test, y_pred, average=None)  
recall = recall_score(y_test, y_pred, average=None)  
f1 = f1_score(y_test, y_pred, average=None)  
# 打印每个类别的精准率、召回率和F1分数  
print("Precision per class:", precision)  
print("Recall per class:", recall)  
print("F1 Score per class:", f1)  
  
# 如果你想获得一个全局的评估指标,可以计算它们的平均值
precision_avg = precision_score(y_test, y_pred, average='macro')  # 宏平均  
recall_avg = recall_score(y_test, y_pred, average='macro')  
f1_avg = f1_score(y_test, y_pred, average='macro')  
  
print("Macro average Precision:", precision_avg)  
print("Macro average Recall:", recall_avg)  
print("Macro average F1 Score:", f1_avg)
import seaborn as sns  
import matplotlib.pyplot as plt  
  
# 绘制混淆矩阵  
plt.figure(figsize=(10, 7))  
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')  
plt.xlabel('Predicted')  
plt.ylabel('Truth')  
plt.show()

运行结果如下

Confusion Matrix:
[[10  0  0]
 [ 0  9  0]
 [ 0  0 11]]
Accuracy: 1.0
Precision per class: [1. 1. 1.]
Recall per class: [1. 1. 1.]
F1 Score per class: [1. 1. 1.]
Macro average Precision: 1.0
Macro average Recall: 1.0
Macro average F1 Score: 1.0

混淆矩阵绘制如下

⭐️ 5. 解读混淆矩阵

在混淆矩阵中,对角线上的值(C[i][i])越大越好,因为它们表示正确分类的样本数量。而非对角线上的值越小越好,因为它们表示错误分类的样本数量。

如果某个类别的TP值很低,而FN值很高,那么说明模型在这个类别上的召回率很低,即模型漏掉了很多这个类别的样本。相反,如果某个类别的FP值很高,而TN值(在多分类问题中不直接计算)相对较低,那么说明模型在这个类别上的精准率很低,即模型错误地将很多其他类别的样本预测为这个类别。

⭐️ 总结

混淆矩阵是评估多分类模型性能的有力工具。通过计算混淆矩阵和基于它计算出的评估指标(如准确率、精准率、召回率和F1分数),我们可以全面地了解模型在各个类别上的表现,并据此对模型进行优化。此外,混淆矩阵的可视化可以帮助我们更直观地理解模型的性能。

笔者水平有限,若有不对的地方欢迎评论指正!

相关文章
|
6月前
|
算法
KNN分类算法
KNN分类算法
135 47
|
6月前
|
机器学习/深度学习
数据分享|R语言逻辑回归、线性判别分析LDA、GAM、MARS、KNN、QDA、决策树、随机森林、SVM分类葡萄酒交叉验证ROC(下)
数据分享|R语言逻辑回归、线性判别分析LDA、GAM、MARS、KNN、QDA、决策树、随机森林、SVM分类葡萄酒交叉验证ROC
WK
|
2月前
|
机器学习/深度学习 算法 数据挖掘
什么是逻辑回归分类器
逻辑回归分类器是一种广泛应用于二分类问题的统计方法,它基于线性组合并通过Sigmoid函数将输出映射为概率值进行分类。核心原理包括:线性组合假设函数、Sigmoid函数转换及基于概率阈值的预测。该模型计算高效、解释性强且鲁棒性好,适用于信用评估、医疗诊断、舆情分析和电商推荐等多种场景。利用现有机器学习库如scikit-learn可简化其实现过程。
WK
30 1
|
4月前
|
机器人 计算机视觉 Python
K-最近邻(KNN)分类器
【7月更文挑战第26天】
44 8
|
5月前
|
机器学习/深度学习 数据可视化 Python
Logistic回归(一)
这篇内容是一个关于逻辑回归的教程概览
|
5月前
|
机器学习/深度学习 算法
Logistic回归(二)
Logistic回归,又称对数几率回归,是用于分类问题的监督学习算法。它基于对数几率(log-odds),通过对数转换几率来确保预测值在0到1之间,适合于二分类任务。模型通过Sigmoid函数(S型曲线)将线性预测转化为概率。逻辑回归损失函数常采用交叉熵,衡量模型预测概率分布与真实标签分布的差异。熵和相对熵(KL散度)是评估分布相似性的度量,低熵表示分布更集中,低交叉熵表示模型预测与真实情况更接近。
|
6月前
|
机器学习/深度学习 数据可视化 计算机视觉
数据分享|R语言逻辑回归、线性判别分析LDA、GAM、MARS、KNN、QDA、决策树、随机森林、SVM分类葡萄酒交叉验证ROC(上)
数据分享|R语言逻辑回归、线性判别分析LDA、GAM、MARS、KNN、QDA、决策树、随机森林、SVM分类葡萄酒交叉验证ROC
|
6月前
|
机器学习/深度学习 数据可视化
深入了解多分类混淆矩阵:解读、应用与实例
深入了解多分类混淆矩阵:解读、应用与实例
深入了解多分类混淆矩阵:解读、应用与实例
|
6月前
|
机器学习/深度学习 数据采集 算法
R语言逻辑回归、GAM、LDA、KNN、PCA主成分分析分类预测房价及交叉验证|数据分享
R语言逻辑回归、GAM、LDA、KNN、PCA主成分分析分类预测房价及交叉验证|数据分享
|
数据可视化
混淆矩阵的生成
混淆矩阵的生成