混淆矩阵的生成

简介: 混淆矩阵的生成

混淆矩阵简介


混淆矩阵(Confusion Matrix)是一个二维表格,常用于评价分类模型的性能。在混淆矩阵中,每一列代表了预测值,每一行代表了真实值。因此,混淆矩阵中的每一个元素表示了一个样本被预测为某一类别的次数。混淆矩阵的构成如下:

预测值=正例 预测值=反例
真实值=正例 TP FN
真实值=反例 FP TN


其中,TP表示真正例(True Positive),FN表示假反例(False Negative),FP表示假正例(False Positive),TN表示真反例(True Negative)。


解释如下:


TP:真正例,指的是模型将正例预测为正例的次数;


FN:假反例,指的是模型将正例预测为反例的次数;


FP:假正例,指的是模型将反例预测为正例的次数;


TN:真反例,指的是模型将反例预测为反例的次数。


混淆矩阵的重要性在于,可以通过计算其中的四个元素,得到各种评价指标,如精确度(Accuracy)、召回率(Recall)、准确率(Precision)和 F1 值等。


精确度(Accuracy):表示模型预测正确的样本数与总样本数之比,即Accuracy=TP+FP+FN+TNTP+TN;


召回率(Recall):表示模型正确预测正例样本的比例,即 Recall=TP+FNTP;

准确率(Precision):表示模型预测为正例的样本中,真正例的比例,即 Precision=TP+FPTP;


F1 值:综合了准确率和召回率,即F1=Precision+Recall2×Precision×Recall。


混淆矩阵也可以可视化,可以使用热力图等图形来展示混淆矩阵中每个元素的数值大小,以便更加直观地理解分类模型的性能。



混淆矩阵的主要作用和意义如下:


评估分类器的性能:混淆矩阵可以帮助我们计算分类器的准确率、召回率、精确率、F1分数等指标,从而评估分类器的性能。


比较不同分类器的性能:混淆矩阵可以帮助我们比较不同分类器的性能,找出最优的分类器。


识别分类器的错误类型:混淆矩阵可以帮助我们了解分类器在哪些情况下容易出错,识别出分类器的错误类型,从而针对性地改进分类器。


优化分类器的阈值:混淆矩阵可以帮助我们优化分类器的阈值,从而提高分类器的性能。

可视化分类器的性能:混淆矩阵可以将分类器的性能可视化,从而更直观地了解分类器的性能。


混淆矩阵可视化代码:

import os
from matplotlib.font_manager import FontProperties
import itertools
import matplotlib.pyplot as plt
import numpy as np
# 绘制混淆矩阵
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """
    - cm : 计算出的混淆矩阵的值
    - classes : 混淆矩阵中每一行每一列对应的列
    - normalize : True:显示百分比, False:显示个数
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("显示百分比:")
        np.set_printoptions(formatter={'float': '{: 0.2f}'.format})
        print(cm)
    else:
        print('显示具体数字:')
        print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    # matplotlib版本问题,如果不加下面这行代码,则绘制的混淆矩阵上下只能显示一半,有的版本的matplotlib不需要下面的代码,分别试一下即可
    plt.ylim(len(classes) - 0.5, -0.5)
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()
cnf_matrix = np.array([[151, 64, 731, 164, 45],
                       [821, 653, 79, 0, 28],
                       [266, 167, 423, 4, 2],
                       [691, 0, 107, 776, 26],
                       [30, 0, 111, 17, 42]])
attack_types = ['Normal', 'DoS', 'Probe', 'R2L', 'U2R']
# 归一化
# plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=True, title='Confusion matrix')
# 不归一化
plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=True, title='Confusion matrix')

其中上述有两种方式可以选择,即一种是归一化,一种是不归一化

归一化设置 normalize=True

结果为:


不归一化设置 normalize=False

结果为:

如果想要配合模型生成混淆矩阵,则需要让神经生成一个混淆矩阵的矩阵序列代码为:


import os
import json
import torch
from torchvision import transforms, datasets
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from model import MobileNetV2
class ConfusionMatrix(object):
    """
    注意,如果显示的图像不全,是matplotlib版本问题
    本例程使用matplotlib-3.2.1(windows and ubuntu)绘制正常
    需要额外安装prettytable库
    """
    def __init__(self, num_classes: int, labels: list):
        self.matrix = np.zeros((num_classes, num_classes))
        self.num_classes = num_classes
        self.labels = labels
    def update(self, preds, labels):
        for p, t in zip(preds, labels):
            self.matrix[p, t] += 1
    def plot(self, normalize=False):
        if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("显示百分比:")
        np.set_printoptions(formatter={'float': '{: 0.2f}'.format})
        print(cm)
      else:
        print('显示具体数字:')
        print(cm)
        matrix = self.matrix
      plt.imshow(matrix , interpolation='nearest', cmap=cmap)
      plt.title(title)
      plt.colorbar()
      tick_marks = np.arange(len(classes))
      plt.xticks(tick_marks, classes, rotation=45)
      plt.yticks(tick_marks, classes)
      # matplotlib版本问题,如果不加下面这行代码,则绘制的混淆矩阵上下只能显示一半,有的版本的matplotlib不需要下面的代码,分别试一下即可
      plt.ylim(len(classes) - 0.5, -0.5)
      fmt = '.2f' if normalize else 'd'
      thresh = cm.max() / 2.
      for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
          plt.text(j, i, format(cm[i, j], fmt),
                   horizontalalignment="center",
                   color="white" if cm[i, j] > thresh else "black")
      plt.tight_layout()
      plt.ylabel('True label')
      plt.xlabel('Predicted label')
      plt.show()
if __name__ == '__main__':
    mylabel = {"4": "4", "5": "5", "6": "6"}
    num_classes=3 #################################
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
    ROOT_DATA = r'D:/other/ClassicalModel/data/flower_datas'  #################################
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    data_transform = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    validate_dataset = datasets.ImageFolder(root=os.path.join(ROOT_DATA, "val"),
                                            transform=data_transform)
    batch_size = 16
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=2)
    net = MobileNetV2(num_classes=num_classes)  ###########################
    # load pretrain weights
    model_weight_path = r"D:/other/ClassicalModel/MobileNet/runs1/mobilenet_v2.pth"  #########################
    assert os.path.exists(model_weight_path), "cannot find {} file".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    net.to(device)
    labels = [label for _, label in mylabel.items()]
    confusion = ConfusionMatrix(num_classes=num_classes, labels=labels) 
    net.eval()
    with torch.no_grad():
        for val_data in tqdm(validate_loader):
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            outputs = torch.softmax(outputs, dim=1)
            outputs = torch.argmax(outputs, dim=1)
            # print('outputs++'+str(outputs.to("cpu").numpy())+'val_labels++'+str(val_labels.numpy()))
            confusion.update(outputs.to("cpu").numpy(), val_labels.to("cpu").numpy())
    confusion.plot()



其中*多的地方需要自行修改,例如

ROOT_DATA = r'D:/other/ClassicalModel/data/flower_datas'  #################################


在这里进行数据集的修改

mylabel = {"4": "4", "5": "5", "6": "6"}
• 1


进行标签的修改

net = MobileNetV2(num_classes=3)  ###########################


在这里进行网络修改

model_weight_path = r"D:/other/ClassicalModel/MobileNet/runs1/mobilenet_v2.pth"  #########################
• 1

在这里进行本地模型权重的修改


相关文章
|
机器学习/深度学习
普通卷积、分组卷积和深度分离卷积概念以及参数量计算
普通卷积、分组卷积和深度分离卷积概念以及参数量计算
2273 0
普通卷积、分组卷积和深度分离卷积概念以及参数量计算
|
7月前
|
边缘计算 人工智能 PyTorch
130_知识蒸馏技术:温度参数与损失函数设计 - 教师-学生模型的优化策略与PyTorch实现
随着大型语言模型(LLM)的规模不断增长,部署这些模型面临着巨大的计算和资源挑战。以DeepSeek-R1为例,其671B参数的规模即使经过INT4量化后,仍需要至少6张高端GPU才能运行,这对于大多数中小型企业和研究机构来说成本过高。知识蒸馏作为一种有效的模型压缩技术,通过将大型教师模型的知识迁移到小型学生模型中,在显著降低模型复杂度的同时保留核心性能,成为解决这一问题的关键技术之一。
635 6
|
计算机视觉
YOLOv11改进策略【SPPF】| AIFI : 基于Transformer的尺度内特征交互,在降低计算成本的同时提高模型的性能
本文探讨了基于AIFI模块的YOLOv11目标检测改进方法。AIFI是RT-DETR中高效混合编码器的一部分,通过在S5特征层上应用单尺度Transformer编码器,减少计算成本并增强概念实体间的联系,从而提升对象定位和识别效果。实验表明,AIFI使模型延迟降低35%,准确性提高0.4%。
1730 20
YOLOv11改进策略【SPPF】| AIFI : 基于Transformer的尺度内特征交互,在降低计算成本的同时提高模型的性能
|
7月前
|
人工智能 编解码 搜索推荐
AI智能换背景,助力电商图片营销升级
电商产品图换背景是提升销量与品牌形象的关键。传统抠图耗时费力,AI技术则实现一键智能换背景,高效精准。本文详解燕雀光年AI全能设计、Canva、Remove.bg等十大AI工具,涵盖功能特点与选型建议,助力商家快速打造高质量、高吸引力的商品图,提升转化率与品牌价值。(238字)
763 0
|
机器学习/深度学习 数据可视化 PyTorch
深度学习之如何使用Grad-CAM绘制自己的特征提取图-(Pytorch代码,详细注释)神经网络可视化-绘制自己的热力图
深度学习之如何使用Grad-CAM绘制自己的特征提取图-(Pytorch代码,详细注释)神经网络可视化-绘制自己的热力图
深度学习之如何使用Grad-CAM绘制自己的特征提取图-(Pytorch代码,详细注释)神经网络可视化-绘制自己的热力图
|
存储 监控 Linux
在Linux中,列出几种常用的Linux备份工具并说明各自的适用场景。
在Linux中,列出几种常用的Linux备份工具并说明各自的适用场景。
|
前端开发 UED
超越静态:CSS动画轮播图,引领视觉新体验!
超越静态:CSS动画轮播图,引领视觉新体验!
|
开发工具 git
码云 -- 本地代码上传到码云
码云 -- 本地代码上传到码云
921 1
|
数据库 数据安全/隐私保护 数据格式
kik-net、peer、yjk、pkpm、midas、sac地震波格式互相转换
地震波格式转换、时程转换、峰值调整、规范反应谱、计算反应谱、计算持时、生成人工波、时频域转换、数据滤波、基线校正、Arias截波、傅里叶变换、耐震时程曲线、脉冲波合成与提取、三联反应谱、地震动参数、延性反应谱、地震波缩尺、功率谱密度
|
机器学习/深度学习 存储 监控
基于YOLOv8深度学习的高压输电线绝缘子缺陷智能检测系统【python源码+Pyqt5界面+数据集+训练代码】深度学习实战、目标检测
基于YOLOv8深度学习的高压输电线绝缘子缺陷智能检测系统【python源码+Pyqt5界面+数据集+训练代码】深度学习实战、目标检测