混淆矩阵的生成

简介: 混淆矩阵的生成

混淆矩阵简介


混淆矩阵(Confusion Matrix)是一个二维表格,常用于评价分类模型的性能。在混淆矩阵中,每一列代表了预测值,每一行代表了真实值。因此,混淆矩阵中的每一个元素表示了一个样本被预测为某一类别的次数。混淆矩阵的构成如下:

预测值=正例 预测值=反例
真实值=正例 TP FN
真实值=反例 FP TN


其中,TP表示真正例(True Positive),FN表示假反例(False Negative),FP表示假正例(False Positive),TN表示真反例(True Negative)。


解释如下:


TP:真正例,指的是模型将正例预测为正例的次数;


FN:假反例,指的是模型将正例预测为反例的次数;


FP:假正例,指的是模型将反例预测为正例的次数;


TN:真反例,指的是模型将反例预测为反例的次数。


混淆矩阵的重要性在于,可以通过计算其中的四个元素,得到各种评价指标,如精确度(Accuracy)、召回率(Recall)、准确率(Precision)和 F1 值等。


精确度(Accuracy):表示模型预测正确的样本数与总样本数之比,即Accuracy=TP+FP+FN+TNTP+TN;


召回率(Recall):表示模型正确预测正例样本的比例,即 Recall=TP+FNTP;

准确率(Precision):表示模型预测为正例的样本中,真正例的比例,即 Precision=TP+FPTP;


F1 值:综合了准确率和召回率,即F1=Precision+Recall2×Precision×Recall。


混淆矩阵也可以可视化,可以使用热力图等图形来展示混淆矩阵中每个元素的数值大小,以便更加直观地理解分类模型的性能。



混淆矩阵的主要作用和意义如下:


评估分类器的性能:混淆矩阵可以帮助我们计算分类器的准确率、召回率、精确率、F1分数等指标,从而评估分类器的性能。


比较不同分类器的性能:混淆矩阵可以帮助我们比较不同分类器的性能,找出最优的分类器。


识别分类器的错误类型:混淆矩阵可以帮助我们了解分类器在哪些情况下容易出错,识别出分类器的错误类型,从而针对性地改进分类器。


优化分类器的阈值:混淆矩阵可以帮助我们优化分类器的阈值,从而提高分类器的性能。

可视化分类器的性能:混淆矩阵可以将分类器的性能可视化,从而更直观地了解分类器的性能。


混淆矩阵可视化代码:

import os
from matplotlib.font_manager import FontProperties
import itertools
import matplotlib.pyplot as plt
import numpy as np
# 绘制混淆矩阵
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """
    - cm : 计算出的混淆矩阵的值
    - classes : 混淆矩阵中每一行每一列对应的列
    - normalize : True:显示百分比, False:显示个数
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("显示百分比:")
        np.set_printoptions(formatter={'float': '{: 0.2f}'.format})
        print(cm)
    else:
        print('显示具体数字:')
        print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    # matplotlib版本问题,如果不加下面这行代码,则绘制的混淆矩阵上下只能显示一半,有的版本的matplotlib不需要下面的代码,分别试一下即可
    plt.ylim(len(classes) - 0.5, -0.5)
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()
cnf_matrix = np.array([[151, 64, 731, 164, 45],
                       [821, 653, 79, 0, 28],
                       [266, 167, 423, 4, 2],
                       [691, 0, 107, 776, 26],
                       [30, 0, 111, 17, 42]])
attack_types = ['Normal', 'DoS', 'Probe', 'R2L', 'U2R']
# 归一化
# plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=True, title='Confusion matrix')
# 不归一化
plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=True, title='Confusion matrix')

其中上述有两种方式可以选择,即一种是归一化,一种是不归一化

归一化设置 normalize=True

结果为:


不归一化设置 normalize=False

结果为:

如果想要配合模型生成混淆矩阵,则需要让神经生成一个混淆矩阵的矩阵序列代码为:


import os
import json
import torch
from torchvision import transforms, datasets
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from model import MobileNetV2
class ConfusionMatrix(object):
    """
    注意,如果显示的图像不全,是matplotlib版本问题
    本例程使用matplotlib-3.2.1(windows and ubuntu)绘制正常
    需要额外安装prettytable库
    """
    def __init__(self, num_classes: int, labels: list):
        self.matrix = np.zeros((num_classes, num_classes))
        self.num_classes = num_classes
        self.labels = labels
    def update(self, preds, labels):
        for p, t in zip(preds, labels):
            self.matrix[p, t] += 1
    def plot(self, normalize=False):
        if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("显示百分比:")
        np.set_printoptions(formatter={'float': '{: 0.2f}'.format})
        print(cm)
      else:
        print('显示具体数字:')
        print(cm)
        matrix = self.matrix
      plt.imshow(matrix , interpolation='nearest', cmap=cmap)
      plt.title(title)
      plt.colorbar()
      tick_marks = np.arange(len(classes))
      plt.xticks(tick_marks, classes, rotation=45)
      plt.yticks(tick_marks, classes)
      # matplotlib版本问题,如果不加下面这行代码,则绘制的混淆矩阵上下只能显示一半,有的版本的matplotlib不需要下面的代码,分别试一下即可
      plt.ylim(len(classes) - 0.5, -0.5)
      fmt = '.2f' if normalize else 'd'
      thresh = cm.max() / 2.
      for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
          plt.text(j, i, format(cm[i, j], fmt),
                   horizontalalignment="center",
                   color="white" if cm[i, j] > thresh else "black")
      plt.tight_layout()
      plt.ylabel('True label')
      plt.xlabel('Predicted label')
      plt.show()
if __name__ == '__main__':
    mylabel = {"4": "4", "5": "5", "6": "6"}
    num_classes=3 #################################
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
    ROOT_DATA = r'D:/other/ClassicalModel/data/flower_datas'  #################################
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    data_transform = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    validate_dataset = datasets.ImageFolder(root=os.path.join(ROOT_DATA, "val"),
                                            transform=data_transform)
    batch_size = 16
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=2)
    net = MobileNetV2(num_classes=num_classes)  ###########################
    # load pretrain weights
    model_weight_path = r"D:/other/ClassicalModel/MobileNet/runs1/mobilenet_v2.pth"  #########################
    assert os.path.exists(model_weight_path), "cannot find {} file".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    net.to(device)
    labels = [label for _, label in mylabel.items()]
    confusion = ConfusionMatrix(num_classes=num_classes, labels=labels) 
    net.eval()
    with torch.no_grad():
        for val_data in tqdm(validate_loader):
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            outputs = torch.softmax(outputs, dim=1)
            outputs = torch.argmax(outputs, dim=1)
            # print('outputs++'+str(outputs.to("cpu").numpy())+'val_labels++'+str(val_labels.numpy()))
            confusion.update(outputs.to("cpu").numpy(), val_labels.to("cpu").numpy())
    confusion.plot()



其中*多的地方需要自行修改,例如

ROOT_DATA = r'D:/other/ClassicalModel/data/flower_datas'  #################################


在这里进行数据集的修改

mylabel = {"4": "4", "5": "5", "6": "6"}
• 1


进行标签的修改

net = MobileNetV2(num_classes=3)  ###########################


在这里进行网络修改

model_weight_path = r"D:/other/ClassicalModel/MobileNet/runs1/mobilenet_v2.pth"  #########################
• 1

在这里进行本地模型权重的修改


相关文章
|
机器学习/深度学习 人工智能 自动驾驶
深度学习-数据增强与扩充
深度学习-数据增强与扩充
|
3月前
|
敏捷开发 自然语言处理 IDE
通义灵码+云效 DevOps MCP:通过云效工作项自动生成代码并提交请求
本文将详细介绍如何利用云效MCP服务,根据工作项内容生成对应代码、创建分支、提交代码,并发起合并请求。
|
10月前
|
监控
SMoA: 基于稀疏混合架构的大语言模型协同优化框架
通过引入稀疏化和角色多样性,SMoA为大语言模型多代理系统的发展开辟了新的方向。
319 6
SMoA: 基于稀疏混合架构的大语言模型协同优化框架
|
存储 监控 Linux
在Linux中,列出几种常用的Linux备份工具并说明各自的适用场景。
在Linux中,列出几种常用的Linux备份工具并说明各自的适用场景。
Axure App 垂直滚动
Axure App 垂直滚动
1102 0
|
存储 弹性计算 固态存储
阿里云服务器CPU内存配置详细指南,如何选择合适云服务器配置?
阿里云服务器配置选择涉及CPU、内存、公网带宽和磁盘。个人开发者或中小企业推荐使用轻量应用服务器或ECS经济型e实例,如2核2G3M配置,适合低流量网站。企业用户则应选择企业级独享型ECS,如通用算力型u1、计算型c7或通用型g7,至少2核4G配置,公网带宽建议5M,系统盘可选SSD或ESSD云盘。选择时考虑实际应用需求和性能稳定性。
1864 6
|
机器学习/深度学习 算法 数据建模
决策树(Decision Tree)算法详解及python实现
决策树(Decision Tree)算法详解及python实现
2496 0
决策树(Decision Tree)算法详解及python实现
|
网络安全 数据安全/隐私保护
TortoiseGit✨解决TortoiseGitPlink要求输入密码
TortoiseGit✨解决TortoiseGitPlink要求输入密码
|
开发工具 git
码云 -- 本地代码上传到码云
码云 -- 本地代码上传到码云
622 1
|
Linux
Linux(9)Debain EC25 quectel-CM usbnet0开机自动联网配置
Linux(9)Debain EC25 quectel-CM usbnet0开机自动联网配置
724 0