机器学习中不平衡数据集分类模型示例:乳腺钼靶微钙化摄影数据集(一)

本文涉及的产品
模型训练 PAI-DLC,5000CU*H 3个月
交互式建模 PAI-DSW,每月250计算时 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
简介: 机器学习中不平衡数据集分类模型示例:乳腺钼靶微钙化摄影数据集(一)

Deephub翻译组:Alexander Zhao

癌症检测是不平衡分类问题的一个普遍例子,因为非癌症病例往往比实际癌症病例多得多。

一个典型的不平衡分类数据集是乳腺摄影数据集,这个数据集用于从放射扫描中检测乳腺癌(特别是在乳腺摄影中出现明亮的微钙化簇)。研究人员通过扫描图像,对目标进行分割,然后用计算机视觉算法描述分割对象,从而获得了这一数据集。

由于类别不平衡十分严重,这是一个非常流行的不平衡分类数据集。其中98%的候选图像不是癌症,只有2%被有经验的放射科医生标记为癌症。

在本教程中,您将发现如何开发和评估乳腺癌钼靶摄影数据集的不平衡分类模型。完成本教程后,您将知道:

  • 如何加载和探索数据集,并从中获得预处理数据与选择模型的灵感。
  • 如何使用代价敏感算法评估一组机器学习模型并提高其性能。
  • 如何拟合最终模型并使用它预测特定情况下的类标签。

我们开始吧。

教程概述

本教程分为五个部分,分别是:

  1. 乳腺摄影数据集
  2. 浏览数据集
  3. 模型试验和基准结果
  4. 评估模型
  1. 评估机器学习算法
  2. 评估代价敏感算法
  1. 对新数据进行预测

乳腺摄影数据集

在这个项目中,我们将使用一个典型的不平衡机器学习数据集,即“乳腺摄影”数据集,有时称为“Woods乳腺摄影数据集”。

数据集归Kevin Woods等人所有,有关这项工作可以参考1993年发表的题为“Comparative Evaluation Of Pattern Recognition Techniques For Detection Of Microcalcifications In Mammography”的论文。

这个问题的焦点是通过放射扫描来检测乳腺癌,特别是在乳房X光片上出现的微小钙化团。

该数据集首先从24张已知癌症诊断结果的乳房X光片开始扫描,然后使用图像分割计算机视觉算法对图像进行预处理,从乳腺图像中提取候选目标。这些候选目标被分割后,就会被一位经验丰富的放射科医生手工标记。

Woods等人首先从分割的对象中选取与模式识别最相关的对象并进行特征提取,总共提取了29个特征,这些特征被减少到18个,然后最后减少到7个,具体如下(摘自论文原文):

  • 对象的面积(像素)
  • 对象的平均灰度
  • 对象周长像素的渐变强度
  • 对象中的均方根噪声波动
  • 对比度,也即对象的平均灰度减去对象周围两个像素宽边框的平均值
  • 基于形状描述子的低阶矩

这是一个二分类任务,目的是利用给定分割对象的特征来区分乳腺影片中的微钙化和非微钙化。

  • 非微钙化: 负类,或多数类
  • 微钙化: 正类,或少数类

作者评价与比较了一系列机器学习模型,包括人工神经网络、决策树、k近邻算法等。各个模型用受试者操作特性曲线(ROC)进行评估,并且用曲线下面积(AUC)进行比较。

选择ROC与AUC作为评估指标,其目的是最小化假阳性率(FPR,即特异性Specificity的补数)并最大化真阳性率(TPR,即敏感性Sensitivity),这也即ROC曲线的两轴。选用ROC曲线的另一个原因是可以让研究人员确定一个概率阈值,从而对可以接受的最大FPR与TPR进行权衡(因为随着TPR的上升FPR也不可避免地上升,译者注)。

研究结果表明,线性分类器(原文中应该是一个高斯朴素贝叶斯分类器)的AUC为0.936(100次实验的平均值),表现最好。

接下来,让我们仔细看看数据。

探索数据集

乳腺摄影数据集是一个广泛使用的标准机器学习数据集,用于探索和演示许多专门为不平衡分类设计的技术。一个典型的例子是流行的SMOTE技术。

我们使用的数据集是其中的一个版本,它与原始文件中描述的数据集有一些不同。

首先,下载数据集并将其保存在当前的工作目录中,名为“mammography.csv”

  • 下载乳房摄影数据集(maxomography.csv)

检查文件的内容。文件的前几行应如下所示:


0.23001961,5.0725783,-0.27606055,0.83244412,-0.37786573,0.4803223,'-1'
0.15549112,-0.16939038,0.67065219,-0.85955255,-0.37786573,-0.94572324,'-1'
-0.78441482,-0.44365372,5.6747053,-0.85955255,-0.37786573,-0.94572324,'-1'
0.54608818,0.13141457,-0.45638679,-0.85955255,-0.37786573,-0.94572324,'-1'
-0.10298725,-0.3949941,-0.14081588,0.97970269,-0.37786573,1.0135658,'-1'

我们可以看到数据集有6个输入变量,而不是7个输入变量。有可能从这个版本的数据集中删除了论文中列出的第一个输入变量(用像素描述的对象面积)。

输入变量是数值类型,而目标变量是多数类置为“-1”、少数类置为“1”的字符串。这些值需要分别编码为0和1,以满足分类算法对二进制不平衡分类问题的期望。

可以使用read_csv()这一Pandas函数将数据集加载为DataFrame数据结构,注意指定header=None


...# define the dataset locationfilename = 'mammography.csv'# load the csv file as a data framedataframe = read_csv(filename, header=None)

载入完毕后,我们调用DataFrame的shape方法打印其行列数。


...# summarize the shape of the datasetprint(dataframe.shape)

我们还可以通过使用Counter来确认数据,获取各类的比例。


...# summarize the class distributiontarget = dataframe.values[:,-1]counter = Counter(target)for k,v in counter.items():  per = v / len(target) * 100 print('Class=%s, Count=%d, Percentage=%.3f%%' % (k, v, per))

把这两步放在一起,完整的载入与确认数据的代码如下:


# load and summarize the datasetfrom pandas import read_csvfrom collections import Counter# define the dataset locationfilename = 'mammography.csv'# load the csv file as a data framedataframe = read_csv(filename, header=None)# summarize the shape of the datasetprint(dataframe.shape)# summarize the class distributiontarget = dataframe.values[:,-1]counter = Counter(target)for k,v in counter.items():  per = v / len(target) * 100 print('Class=%s, Count=%d, Percentage=%.3f%%' % (k, v, per))

运行示例代码,首先加载数据集并确认行和列的数量,即11183行、6个输入变量和1个目标变量。

然后确认类别分布,我们会观察到严重的类别不平衡:多数类(无癌症)约占98%,少数类(癌症)约占2%。


(11183, 7)Class='-1', Count=10923, Percentage=97.675%Class='1', Count=260, Percentage=2.325%

从反例和正例的比例来看,数据集似乎与SMOTE论文中描述的数据集基本匹配。

A typical mammography dataset might contain 98% normal pixels and 2% abnormal pixels.

— SMOTE: Synthetic Minority Over-sampling Technique, 2002.

此外,正反例的数量也基本与论文相符。

The experiments were conducted on the mammography dataset. There were 10923 examples in the majority class and 260 examples in the minority class originally.

— SMOTE: Synthetic Minority Over-sampling Technique, 2002.

我相信这是同一个数据集,尽管我无法解释输入特征数量的不匹配现象,例如我们的数据集中只有6个输入数据,而原始论文中有7个。

我们还可以为每个变量创建直方图来观察输入变量的分布,下面列出了完整的示例。


# create histograms of numeric input variablesfrom pandas import read_csvfrom matplotlib import pyplot# define the dataset locationfilename = 'mammography.csv'# load the csv file as a data framedf = read_csv(filename, header=None)# histograms of all variablesdf.hist()pyplot.show()

运行该示例代码将为数据集中的六个输入变量分别创建一个直方图。我们可以看到,这些变量有不同的取值范围,而且大多数变量都是指数分布的,例如,大多数情况下变量只占据直方图的一列,而其他情况下则留下一个长尾,而最后一个变量则似乎具有双峰分布。

根据我们选择的算法,将数据分布缩放到相同的取值范围是可能是很有用的,也许还需要使用一些幂变换,这将在后文进行讨论。

image.png

数据分布

我们还可以为每对输入变量创建一个散点图,称为散点图矩阵。这有助于我们了解是否有任何变量是相互关联的,或在同一方向上发生变化。我们还可以根据类标签给每个散点图上色。我们将多数类(没有癌症)标记为蓝点,少数类(癌症)标记为红点。

下面列出了完整的示例。


# create pairwise scatter plots of numeric input variablesfrom pandas import read_csvfrom pandas.plotting import scatter_matrixfrom matplotlib import pyplot# define the dataset locationfilename = 'mammography.csv'# load the csv file as a data framedf = read_csv(filename, header=None)# define a mapping of class values to colorscolor_dict = {"'-1'":'blue', "'1'":'red'}# map each row to a color based on the class valuecolors = [color_dict[str(x)] for x in df.values[:, -1]]# pairwise scatter plots of all numerical variablesscatter_matrix(df, diagonal='kde', color=colors)pyplot.show()

运行该示例将创建一个6组*6组的散点图矩阵,用于六个输入变量的相互比较。矩阵的对角线表示每个变量的密度分布。

每一对变量的分布比较都出现了两次,分别位于主对角线元素的左侧与上侧(或右侧与下侧)。这提供了两种查看变量分布的尺度。

我们可以看到,对于正类与负类,许多变量的分布确实不同,这表明在癌症病例和非癌症病例之间进行合理的区分是可行的。

image.png

散点图矩阵

现在我们已经回顾了数据集,接下来让我们来评估与测试备选模型。

目录
相关文章
|
1月前
|
机器人
1024 云上见 使用 PAI+LLaMA Factory 微调 Qwen2-VL 模型,搭建 “文旅领域知识问答机器人” 领精美计时器
1024 云上见 使用 PAI+LLaMA Factory 微调 Qwen2-VL 模型,搭建 “文旅领域知识问答机器人” 领精美计时器
86 3
|
1月前
|
机器学习/深度学习 数据采集 监控
如何使用机器学习模型来自动化评估数据质量?
【10月更文挑战第6天】如何使用机器学习模型来自动化评估数据质量?
|
10天前
|
机器学习/深度学习 数据采集 监控
如何使用机器学习模型来自动化评估数据质量?
如何使用机器学习模型来自动化评估数据质量?
|
7天前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
24 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
11天前
|
机器学习/深度学习 算法 PyTorch
用Python实现简单机器学习模型:以鸢尾花数据集为例
用Python实现简单机器学习模型:以鸢尾花数据集为例
33 1
|
1月前
|
数据采集 移动开发 数据可视化
模型预测笔记(一):数据清洗分析及可视化、模型搭建、模型训练和预测代码一体化和对应结果展示(可作为baseline)
这篇文章介绍了数据清洗、分析、可视化、模型搭建、训练和预测的全过程,包括缺失值处理、异常值处理、特征选择、数据归一化等关键步骤,并展示了模型融合技术。
54 1
模型预测笔记(一):数据清洗分析及可视化、模型搭建、模型训练和预测代码一体化和对应结果展示(可作为baseline)
|
1月前
|
XML JSON 数据可视化
数据集学习笔记(二): 转换不同类型的数据集用于模型训练(XML、VOC、YOLO、COCO、JSON、PNG)
本文详细介绍了不同数据集格式之间的转换方法,包括YOLO、VOC、COCO、JSON、TXT和PNG等格式,以及如何可视化验证数据集。
68 1
数据集学习笔记(二): 转换不同类型的数据集用于模型训练(XML、VOC、YOLO、COCO、JSON、PNG)
|
20天前
|
机器学习/深度学习 数据采集 Python
从零到一:手把手教你完成机器学习项目,从数据预处理到模型部署全攻略
【10月更文挑战第25天】本文通过一个预测房价的案例,详细介绍了从数据预处理到模型部署的完整机器学习项目流程。涵盖数据清洗、特征选择与工程、模型训练与调优、以及使用Flask进行模型部署的步骤,帮助读者掌握机器学习的最佳实践。
59 1
|
23天前
|
机器学习/深度学习 数据采集 监控
如何使用机器学习模型来自动化评估数据质量?
如何使用机器学习模型来自动化评估数据质量?
|
29天前
|
机器人
1024 云上见 使用 PAI+LLaMA Factory 微调 Qwen2-VL 模型,搭建 “文旅领域知识问答机器人” 领 200个 精美计时器等你领
1024 云上见 使用 PAI+LLaMA Factory 微调 Qwen2-VL 模型,搭建 “文旅领域知识问答机器人” 领 200个 精美计时器等你领
73 2

热门文章

最新文章

相关产品

  • 人工智能平台 PAI