最近,MMDetection 的新版本 V2.18.1 中加入了社区用户呼唤已久的混淆矩阵绘制功能。
话不多说,先上图!
👇
图1 混淆矩阵示例
怎么样,是不是很好看?
好看是好看,可惜就是有点看不懂(装傻中)
那么本篇文章我们就来详细介绍一下什么是混淆矩阵,以及如何理解目标检测中的混淆矩阵。
1. 什么是混淆矩阵
首先给出定义:在机器学习领域,特别是统计分类问题中,混淆矩阵(confusion matrix)是一种特定的表格布局,用于可视化算法的性能,矩阵的每一行代表实际的类别,而每一列代表预测的类别。
这么说可能有些抽象,那么就先来看一个最简单的例子:二分类的混淆矩阵。
图2 二分类混淆矩阵
上图这张 2 x 2 大小的矩阵就是一个最简单的二分类混淆矩阵,只区分 0 或 1。它的行代表真实的类别,列代表预测的类别。以第一行为例,真正的类别标签是 0,从列方向的预测标签来看,有 8 个实例被预测为了 0,有 2 个实例被预测为了 1。也就是说,在这 10 个真实标签为 0 实例中,有 8 个被正确分类,有 2 个被错误分类。
用同样的思路看第二行,那么就很容易理解了。第二行这 10 个真实标签为 1 的实例中,3 个预测错了,7个预测对了。
别看这个表格只包含四个数字,但其中能表述的含义却非常丰富,通过这四个数字的组合计算,就能够计算出TP,FP,FN 以及 TN,然后衍生出其它更多的模型评估指标。
图3 混淆矩阵的衍生(图片来源:wikipedia)
上图是来自维基百科上的一张表格,可以看到从混淆矩阵中的这些值,可以计算出非常丰富的评价指标,由于篇幅有限,这里不再一一介绍这些指标的含义,感兴趣的读者可以前往原表格中的链接进一步了解。
二分类的混淆矩阵想必大家都理解了,那么把问题拓展到多分类中又是怎样的一种情况呢?
图4 多分类混淆矩阵
上图就是一个四分类的混淆矩阵,与二分类的唯一不同就在于分类的标签不再是非正即负,而是会被预测为更多的类别。如果理解了之前二分类的含义,那么很容易就能理解这张多分类混淆矩阵。
同样以第一行为例,真实的标签是猫猫,但是在这十个猫猫中有一个被误分类为狗,一个被误分类为羊,我们就可以很容易的计算出猫的分类正确率为 80%,也可以很直观的看出有那些类别容易存在误识别。其它行的结果也以此类推,就不再赘述。
2. 目标检测中的混淆矩阵
经过上面的讲解,想必大家对分类任务中的混淆矩阵已经非常理解了,那么我们就把目光转向另一个任务——目标检测。
目标检测中的混淆矩阵与分类中的非常相近,但是区别就在于分类任务的对象是一张张图片,而检测任务不一样,它包含定位与分类两个任务,并且对象是图片中的各个目标。
因此为了能够绘制混淆矩阵中的正负例,就需要去区分检测结果中哪些结果是正确的,哪些结果是错误的,同时,对错误的检测也需要归为不同的错误类别。
图5. 检测类型的判别
让我们来重温一下目标检测中的最基本概念:如何判断一个检测结果是否正确。目前最常用的方式就是去计算检测框与真实框的IOU,然后根据 IOU 去判别两个框是否匹配。以上图第一张为例,红色为模型预测的结果,绿色为真实标注,这两个框的 IOU 大于了阈值,因此被判定为匹配,同时这两个框对应的类别也相同,因此是正确的检测结果(TP)。
第二张图中虽然 IOU 大于了阈值,但由于类别不正确,因此被判别为误检。第三张图的检测框 IOU 小于了阈值,没有与真实标注匹配,因此被判别为背景的误检。第四张图没有检测框,属于漏检(FN)。
图6 目标检测中的混淆矩阵
这些被分门别类的检测结果就可以填充到上图的矩阵中,这就是目标检测中的混淆矩阵。
图7 混淆矩阵中数值的含义
让我们再带着上一章节对分类混淆矩阵的理解来看这张图,就非常容易理解了。同样以第一行为例,在这 12 个真实标签为猫的框中,有 8 个正确识别为了猫,有 1 个被误识别为狗,1 个被误识别为羊,还有两只猫没被识别出来。
通过这些数据,就能够很清晰的看出所测试的模型在检测猫这个目标时的性能了。
3. 使用 MMDetection 绘制混淆矩阵
在理解了什么是混淆矩阵以及如何分析混淆矩阵之后,就可以使用 MMDetection 中提供的小工具,为自己的目标检测模型绘制一个混淆矩阵。
首先,我们需要有一份数据集(包含训练集和验证集)以及在这个数据集的训练集上训练得到的检测模型(本文使用 Pascal VOC 数据集以及 RetinaNet 作为示例)。
然后,我们需要用模型推理验证集中的所有图片,并获取检测结果,具体操作为:
运行 tools/test.py 并获得 检测结果 results.pkl 文件:
python tools/test.py \ ${CONFIG} \ --out results.pkl
对于示例来说,即输入命令:
python tools/test.py \ configs/pascal_voc/retinanet_r50_fpn_1x_voc0712.py \ --out results.pkl
在迭代完成之后就会在当前目录下生成检测结果文件results.pkl。
然后,就可以运行我们的混淆矩阵分析工具来绘制混淆矩阵,具体操作为:
python tools/analysis_tools/confusion_matrix.py \ ${CONFIG} \ ${DETECTION_RESULTS} \ ${SAVE_DIR} \ --show
对于示例来说,即输入命令:
python tools/analysis_tools/confusion_matrix.py \ configs/pascal_voc/retinanet_r50_fpn_1x_voc0712.py \ ./results.pkl \ ./ \ --show
就可以获得一张混淆矩阵图了。
与上文中不一样的是,这张混淆矩阵图是在行方向归一化过的。这是由于检测数据集中的目标过多,每个类别一般都会有成百上千的目标,为了能更好看的显示,同时也为了能够更直观的看出每个类别的识别率和误识别率,这里就对混淆矩阵的每一行中的数值都除以了对应类别的总数进行归一化,以百分比来表示。
以 cat 这一行的结果为例:由于行方向代表真是标签,列方向代表预测的类别,因此就能够从这一行的数值中得到猫的正确检测率有 75%,而被误检为狗的概率有 12%。从最后一列也能看出,有 4% 的猫存在漏检。
除了猫狗之间容易出现误识别,牛和马,公交车和轿车,沙发和椅子之间也都存在误识别。
如果我们单看混淆矩阵的最右边一列,就能够看出每个类别漏检的概率。比如对于这个模型来说,盆栽的漏检高达 29% ,而瓶子的漏检也有 22% 。
另外,如果单看最下面一行,也能够看出不同类别的误报率。其中误报最多的是人这个类别,占所有误报的 33%,其次是椅子,有 11% 的误报。知道了这些信息,我们就能够更有针对性的去优化我们的模型。
4. 总结
看完本文,想必大家都已经对混淆矩阵有了较为全面的理解了,那还等什么,赶紧打开 MMDetection 来给自己的检测模型也画一幅吧!
文章来源:公众号【OpenMMLab】
2021-12-09 20:35