混淆矩阵的生成

简介: 混淆矩阵的生成

混淆矩阵简介


混淆矩阵(Confusion Matrix)是一个二维表格,常用于评价分类模型的性能。在混淆矩阵中,每一列代表了预测值,每一行代表了真实值。因此,混淆矩阵中的每一个元素表示了一个样本被预测为某一类别的次数。混淆矩阵的构成如下:

预测值=正例 预测值=反例
真实值=正例 TP FN
真实值=反例 FP TN


其中,TP表示真正例(True Positive),FN表示假反例(False Negative),FP表示假正例(False Positive),TN表示真反例(True Negative)。


解释如下:


TP:真正例,指的是模型将正例预测为正例的次数;


FN:假反例,指的是模型将正例预测为反例的次数;


FP:假正例,指的是模型将反例预测为正例的次数;


TN:真反例,指的是模型将反例预测为反例的次数。


混淆矩阵的重要性在于,可以通过计算其中的四个元素,得到各种评价指标,如精确度(Accuracy)、召回率(Recall)、准确率(Precision)和 F1 值等。


精确度(Accuracy):表示模型预测正确的样本数与总样本数之比,即Accuracy=TP+FP+FN+TNTP+TN;


召回率(Recall):表示模型正确预测正例样本的比例,即 Recall=TP+FNTP;

准确率(Precision):表示模型预测为正例的样本中,真正例的比例,即 Precision=TP+FPTP;


F1 值:综合了准确率和召回率,即F1=Precision+Recall2×Precision×Recall。


混淆矩阵也可以可视化,可以使用热力图等图形来展示混淆矩阵中每个元素的数值大小,以便更加直观地理解分类模型的性能。



混淆矩阵的主要作用和意义如下:


评估分类器的性能:混淆矩阵可以帮助我们计算分类器的准确率、召回率、精确率、F1分数等指标,从而评估分类器的性能。


比较不同分类器的性能:混淆矩阵可以帮助我们比较不同分类器的性能,找出最优的分类器。


识别分类器的错误类型:混淆矩阵可以帮助我们了解分类器在哪些情况下容易出错,识别出分类器的错误类型,从而针对性地改进分类器。


优化分类器的阈值:混淆矩阵可以帮助我们优化分类器的阈值,从而提高分类器的性能。

可视化分类器的性能:混淆矩阵可以将分类器的性能可视化,从而更直观地了解分类器的性能。


混淆矩阵可视化代码:

import os
from matplotlib.font_manager import FontProperties
import itertools
import matplotlib.pyplot as plt
import numpy as np
# 绘制混淆矩阵
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """
    - cm : 计算出的混淆矩阵的值
    - classes : 混淆矩阵中每一行每一列对应的列
    - normalize : True:显示百分比, False:显示个数
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("显示百分比:")
        np.set_printoptions(formatter={'float': '{: 0.2f}'.format})
        print(cm)
    else:
        print('显示具体数字:')
        print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    # matplotlib版本问题,如果不加下面这行代码,则绘制的混淆矩阵上下只能显示一半,有的版本的matplotlib不需要下面的代码,分别试一下即可
    plt.ylim(len(classes) - 0.5, -0.5)
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()
cnf_matrix = np.array([[151, 64, 731, 164, 45],
                       [821, 653, 79, 0, 28],
                       [266, 167, 423, 4, 2],
                       [691, 0, 107, 776, 26],
                       [30, 0, 111, 17, 42]])
attack_types = ['Normal', 'DoS', 'Probe', 'R2L', 'U2R']
# 归一化
# plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=True, title='Confusion matrix')
# 不归一化
plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=True, title='Confusion matrix')

其中上述有两种方式可以选择,即一种是归一化,一种是不归一化

归一化设置 normalize=True

结果为:


不归一化设置 normalize=False

结果为:

如果想要配合模型生成混淆矩阵,则需要让神经生成一个混淆矩阵的矩阵序列代码为:


import os
import json
import torch
from torchvision import transforms, datasets
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from model import MobileNetV2
class ConfusionMatrix(object):
    """
    注意,如果显示的图像不全,是matplotlib版本问题
    本例程使用matplotlib-3.2.1(windows and ubuntu)绘制正常
    需要额外安装prettytable库
    """
    def __init__(self, num_classes: int, labels: list):
        self.matrix = np.zeros((num_classes, num_classes))
        self.num_classes = num_classes
        self.labels = labels
    def update(self, preds, labels):
        for p, t in zip(preds, labels):
            self.matrix[p, t] += 1
    def plot(self, normalize=False):
        if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("显示百分比:")
        np.set_printoptions(formatter={'float': '{: 0.2f}'.format})
        print(cm)
      else:
        print('显示具体数字:')
        print(cm)
        matrix = self.matrix
      plt.imshow(matrix , interpolation='nearest', cmap=cmap)
      plt.title(title)
      plt.colorbar()
      tick_marks = np.arange(len(classes))
      plt.xticks(tick_marks, classes, rotation=45)
      plt.yticks(tick_marks, classes)
      # matplotlib版本问题,如果不加下面这行代码,则绘制的混淆矩阵上下只能显示一半,有的版本的matplotlib不需要下面的代码,分别试一下即可
      plt.ylim(len(classes) - 0.5, -0.5)
      fmt = '.2f' if normalize else 'd'
      thresh = cm.max() / 2.
      for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
          plt.text(j, i, format(cm[i, j], fmt),
                   horizontalalignment="center",
                   color="white" if cm[i, j] > thresh else "black")
      plt.tight_layout()
      plt.ylabel('True label')
      plt.xlabel('Predicted label')
      plt.show()
if __name__ == '__main__':
    mylabel = {"4": "4", "5": "5", "6": "6"}
    num_classes=3 #################################
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
    ROOT_DATA = r'D:/other/ClassicalModel/data/flower_datas'  #################################
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    data_transform = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    validate_dataset = datasets.ImageFolder(root=os.path.join(ROOT_DATA, "val"),
                                            transform=data_transform)
    batch_size = 16
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=2)
    net = MobileNetV2(num_classes=num_classes)  ###########################
    # load pretrain weights
    model_weight_path = r"D:/other/ClassicalModel/MobileNet/runs1/mobilenet_v2.pth"  #########################
    assert os.path.exists(model_weight_path), "cannot find {} file".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    net.to(device)
    labels = [label for _, label in mylabel.items()]
    confusion = ConfusionMatrix(num_classes=num_classes, labels=labels) 
    net.eval()
    with torch.no_grad():
        for val_data in tqdm(validate_loader):
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            outputs = torch.softmax(outputs, dim=1)
            outputs = torch.argmax(outputs, dim=1)
            # print('outputs++'+str(outputs.to("cpu").numpy())+'val_labels++'+str(val_labels.numpy()))
            confusion.update(outputs.to("cpu").numpy(), val_labels.to("cpu").numpy())
    confusion.plot()



其中*多的地方需要自行修改,例如

ROOT_DATA = r'D:/other/ClassicalModel/data/flower_datas'  #################################


在这里进行数据集的修改

mylabel = {"4": "4", "5": "5", "6": "6"}
• 1


进行标签的修改

net = MobileNetV2(num_classes=3)  ###########################


在这里进行网络修改

model_weight_path = r"D:/other/ClassicalModel/MobileNet/runs1/mobilenet_v2.pth"  #########################
• 1

在这里进行本地模型权重的修改


相关文章
|
2月前
|
机器学习/深度学习 数据可视化
混淆矩阵与 ROC 曲线:何时使用哪个进行模型评估
混淆矩阵与 ROC 曲线:何时使用哪个进行模型评估
61 11
|
7月前
|
机器学习/深度学习
数据分享|R语言逻辑回归、线性判别分析LDA、GAM、MARS、KNN、QDA、决策树、随机森林、SVM分类葡萄酒交叉验证ROC(下)
数据分享|R语言逻辑回归、线性判别分析LDA、GAM、MARS、KNN、QDA、决策树、随机森林、SVM分类葡萄酒交叉验证ROC
|
3月前
|
机器学习/深度学习
R语言模型评估:深入理解混淆矩阵与ROC曲线
【9月更文挑战第2天】混淆矩阵和ROC曲线是评估分类模型性能的两种重要工具。混淆矩阵提供了模型在不同类别上的详细表现,而ROC曲线则通过综合考虑真正率和假正率来全面评估模型的分类能力。在R语言中,利用`caret`和`pROC`等包可以方便地实现这两种评估方法,从而帮助我们更好地理解和选择最适合当前任务的模型。
|
6月前
|
机器学习/深度学习 数据可视化 C语言
多分类混淆矩阵详解
多分类混淆矩阵详解
426 0
|
7月前
|
机器学习/深度学习 数据可视化 计算机视觉
数据分享|R语言逻辑回归、线性判别分析LDA、GAM、MARS、KNN、QDA、决策树、随机森林、SVM分类葡萄酒交叉验证ROC(上)
数据分享|R语言逻辑回归、线性判别分析LDA、GAM、MARS、KNN、QDA、决策树、随机森林、SVM分类葡萄酒交叉验证ROC
|
7月前
Ridge,Lasso,Elasticnet回归
这篇文章探讨了多元线性回归与正则化的结合,包括Ridge、Lasso和Elasticnet回归。Ridge回归通过添加L2惩罚项提高模型鲁棒性,但可能牺牲一些准确性。Lasso回归引入L1范数,对异常值更敏感,能进行特征选择。Elasticnet结合L1和L2范数,允许在正则化中平衡两者。通过调整α和l1_ratio参数,可以控制整体正则化强度和正则化类型的比例。
68 0
|
7月前
|
机器学习/深度学习 数据可视化
深入了解多分类混淆矩阵:解读、应用与实例
深入了解多分类混淆矩阵:解读、应用与实例
深入了解多分类混淆矩阵:解读、应用与实例
|
7月前
|
机器学习/深度学习 数据采集 算法
R语言逻辑回归、GAM、LDA、KNN、PCA主成分分析分类预测房价及交叉验证|数据分享
R语言逻辑回归、GAM、LDA、KNN、PCA主成分分析分类预测房价及交叉验证|数据分享
|
7月前
|
机器学习/深度学习 数据可视化 数据挖掘
R语言用逻辑回归预测BRFSS中风数据、方差分析anova、ROC曲线AUC、可视化探索
R语言用逻辑回归预测BRFSS中风数据、方差分析anova、ROC曲线AUC、可视化探索
|
7月前
|
机器学习/深度学习 数据可视化 算法
多项式Logistic逻辑回归进行多类别分类和交叉验证准确度箱线图可视化
多项式Logistic逻辑回归进行多类别分类和交叉验证准确度箱线图可视化