机器学习中不平衡数据集分类模型示例:乳腺钼靶微钙化摄影数据集(一)

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
交互式建模 PAI-DSW,每月250计算时 3个月
模型训练 PAI-DLC,100CU*H 3个月
简介: 机器学习中不平衡数据集分类模型示例:乳腺钼靶微钙化摄影数据集(一)

Deephub翻译组:Alexander Zhao

癌症检测是不平衡分类问题的一个普遍例子,因为非癌症病例往往比实际癌症病例多得多。

一个典型的不平衡分类数据集是乳腺摄影数据集,这个数据集用于从放射扫描中检测乳腺癌(特别是在乳腺摄影中出现明亮的微钙化簇)。研究人员通过扫描图像,对目标进行分割,然后用计算机视觉算法描述分割对象,从而获得了这一数据集。

由于类别不平衡十分严重,这是一个非常流行的不平衡分类数据集。其中98%的候选图像不是癌症,只有2%被有经验的放射科医生标记为癌症。

在本教程中,您将发现如何开发和评估乳腺癌钼靶摄影数据集的不平衡分类模型。完成本教程后,您将知道:

  • 如何加载和探索数据集,并从中获得预处理数据与选择模型的灵感。
  • 如何使用代价敏感算法评估一组机器学习模型并提高其性能。
  • 如何拟合最终模型并使用它预测特定情况下的类标签。

我们开始吧。

教程概述

本教程分为五个部分,分别是:

  1. 乳腺摄影数据集
  2. 浏览数据集
  3. 模型试验和基准结果
  4. 评估模型
  1. 评估机器学习算法
  2. 评估代价敏感算法
  1. 对新数据进行预测

乳腺摄影数据集

在这个项目中,我们将使用一个典型的不平衡机器学习数据集,即“乳腺摄影”数据集,有时称为“Woods乳腺摄影数据集”。

数据集归Kevin Woods等人所有,有关这项工作可以参考1993年发表的题为“Comparative Evaluation Of Pattern Recognition Techniques For Detection Of Microcalcifications In Mammography”的论文。

这个问题的焦点是通过放射扫描来检测乳腺癌,特别是在乳房X光片上出现的微小钙化团。

该数据集首先从24张已知癌症诊断结果的乳房X光片开始扫描,然后使用图像分割计算机视觉算法对图像进行预处理,从乳腺图像中提取候选目标。这些候选目标被分割后,就会被一位经验丰富的放射科医生手工标记。

Woods等人首先从分割的对象中选取与模式识别最相关的对象并进行特征提取,总共提取了29个特征,这些特征被减少到18个,然后最后减少到7个,具体如下(摘自论文原文):

  • 对象的面积(像素)
  • 对象的平均灰度
  • 对象周长像素的渐变强度
  • 对象中的均方根噪声波动
  • 对比度,也即对象的平均灰度减去对象周围两个像素宽边框的平均值
  • 基于形状描述子的低阶矩

这是一个二分类任务,目的是利用给定分割对象的特征来区分乳腺影片中的微钙化和非微钙化。

  • 非微钙化: 负类,或多数类
  • 微钙化: 正类,或少数类

作者评价与比较了一系列机器学习模型,包括人工神经网络、决策树、k近邻算法等。各个模型用受试者操作特性曲线(ROC)进行评估,并且用曲线下面积(AUC)进行比较。

选择ROC与AUC作为评估指标,其目的是最小化假阳性率(FPR,即特异性Specificity的补数)并最大化真阳性率(TPR,即敏感性Sensitivity),这也即ROC曲线的两轴。选用ROC曲线的另一个原因是可以让研究人员确定一个概率阈值,从而对可以接受的最大FPR与TPR进行权衡(因为随着TPR的上升FPR也不可避免地上升,译者注)。

研究结果表明,线性分类器(原文中应该是一个高斯朴素贝叶斯分类器)的AUC为0.936(100次实验的平均值),表现最好。

接下来,让我们仔细看看数据。

探索数据集

乳腺摄影数据集是一个广泛使用的标准机器学习数据集,用于探索和演示许多专门为不平衡分类设计的技术。一个典型的例子是流行的SMOTE技术。

我们使用的数据集是其中的一个版本,它与原始文件中描述的数据集有一些不同。

首先,下载数据集并将其保存在当前的工作目录中,名为“mammography.csv”

  • 下载乳房摄影数据集(maxomography.csv)

检查文件的内容。文件的前几行应如下所示:


0.23001961,5.0725783,-0.27606055,0.83244412,-0.37786573,0.4803223,'-1'
0.15549112,-0.16939038,0.67065219,-0.85955255,-0.37786573,-0.94572324,'-1'
-0.78441482,-0.44365372,5.6747053,-0.85955255,-0.37786573,-0.94572324,'-1'
0.54608818,0.13141457,-0.45638679,-0.85955255,-0.37786573,-0.94572324,'-1'
-0.10298725,-0.3949941,-0.14081588,0.97970269,-0.37786573,1.0135658,'-1'

我们可以看到数据集有6个输入变量,而不是7个输入变量。有可能从这个版本的数据集中删除了论文中列出的第一个输入变量(用像素描述的对象面积)。

输入变量是数值类型,而目标变量是多数类置为“-1”、少数类置为“1”的字符串。这些值需要分别编码为0和1,以满足分类算法对二进制不平衡分类问题的期望。

可以使用read_csv()这一Pandas函数将数据集加载为DataFrame数据结构,注意指定header=None


...# define the dataset locationfilename = 'mammography.csv'# load the csv file as a data framedataframe = read_csv(filename, header=None)

载入完毕后,我们调用DataFrame的shape方法打印其行列数。


...# summarize the shape of the datasetprint(dataframe.shape)

我们还可以通过使用Counter来确认数据,获取各类的比例。


...# summarize the class distributiontarget = dataframe.values[:,-1]counter = Counter(target)for k,v in counter.items():  per = v / len(target) * 100 print('Class=%s, Count=%d, Percentage=%.3f%%' % (k, v, per))

把这两步放在一起,完整的载入与确认数据的代码如下:


# load and summarize the datasetfrom pandas import read_csvfrom collections import Counter# define the dataset locationfilename = 'mammography.csv'# load the csv file as a data framedataframe = read_csv(filename, header=None)# summarize the shape of the datasetprint(dataframe.shape)# summarize the class distributiontarget = dataframe.values[:,-1]counter = Counter(target)for k,v in counter.items():  per = v / len(target) * 100 print('Class=%s, Count=%d, Percentage=%.3f%%' % (k, v, per))

运行示例代码,首先加载数据集并确认行和列的数量,即11183行、6个输入变量和1个目标变量。

然后确认类别分布,我们会观察到严重的类别不平衡:多数类(无癌症)约占98%,少数类(癌症)约占2%。


(11183, 7)Class='-1', Count=10923, Percentage=97.675%Class='1', Count=260, Percentage=2.325%

从反例和正例的比例来看,数据集似乎与SMOTE论文中描述的数据集基本匹配。

A typical mammography dataset might contain 98% normal pixels and 2% abnormal pixels.

— SMOTE: Synthetic Minority Over-sampling Technique, 2002.

此外,正反例的数量也基本与论文相符。

The experiments were conducted on the mammography dataset. There were 10923 examples in the majority class and 260 examples in the minority class originally.

— SMOTE: Synthetic Minority Over-sampling Technique, 2002.

我相信这是同一个数据集,尽管我无法解释输入特征数量的不匹配现象,例如我们的数据集中只有6个输入数据,而原始论文中有7个。

我们还可以为每个变量创建直方图来观察输入变量的分布,下面列出了完整的示例。


# create histograms of numeric input variablesfrom pandas import read_csvfrom matplotlib import pyplot# define the dataset locationfilename = 'mammography.csv'# load the csv file as a data framedf = read_csv(filename, header=None)# histograms of all variablesdf.hist()pyplot.show()

运行该示例代码将为数据集中的六个输入变量分别创建一个直方图。我们可以看到,这些变量有不同的取值范围,而且大多数变量都是指数分布的,例如,大多数情况下变量只占据直方图的一列,而其他情况下则留下一个长尾,而最后一个变量则似乎具有双峰分布。

根据我们选择的算法,将数据分布缩放到相同的取值范围是可能是很有用的,也许还需要使用一些幂变换,这将在后文进行讨论。

image.png

数据分布

我们还可以为每对输入变量创建一个散点图,称为散点图矩阵。这有助于我们了解是否有任何变量是相互关联的,或在同一方向上发生变化。我们还可以根据类标签给每个散点图上色。我们将多数类(没有癌症)标记为蓝点,少数类(癌症)标记为红点。

下面列出了完整的示例。


# create pairwise scatter plots of numeric input variablesfrom pandas import read_csvfrom pandas.plotting import scatter_matrixfrom matplotlib import pyplot# define the dataset locationfilename = 'mammography.csv'# load the csv file as a data framedf = read_csv(filename, header=None)# define a mapping of class values to colorscolor_dict = {"'-1'":'blue', "'1'":'red'}# map each row to a color based on the class valuecolors = [color_dict[str(x)] for x in df.values[:, -1]]# pairwise scatter plots of all numerical variablesscatter_matrix(df, diagonal='kde', color=colors)pyplot.show()

运行该示例将创建一个6组*6组的散点图矩阵,用于六个输入变量的相互比较。矩阵的对角线表示每个变量的密度分布。

每一对变量的分布比较都出现了两次,分别位于主对角线元素的左侧与上侧(或右侧与下侧)。这提供了两种查看变量分布的尺度。

我们可以看到,对于正类与负类,许多变量的分布确实不同,这表明在癌症病例和非癌症病例之间进行合理的区分是可行的。

image.png

散点图矩阵

现在我们已经回顾了数据集,接下来让我们来评估与测试备选模型。

目录
相关文章
|
29天前
|
机器学习/深度学习 算法 数据挖掘
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构。本文介绍了K-means算法的基本原理,包括初始化、数据点分配与簇中心更新等步骤,以及如何在Python中实现该算法,最后讨论了其优缺点及应用场景。
96 4
|
29天前
|
人工智能 JSON 算法
Qwen2.5-Coder 系列模型在 PAI-QuickStart 的训练、评测、压缩及部署实践
阿里云的人工智能平台 PAI,作为一站式、 AI Native 的大模型与 AIGC 工程平台,为开发者和企业客户提供了 Qwen2.5-Coder 系列模型的全链路最佳实践。本文以Qwen2.5-Coder-32B为例,详细介绍在 PAI-QuickStart 完成 Qwen2.5-Coder 的训练、评测和快速部署。
Qwen2.5-Coder 系列模型在 PAI-QuickStart 的训练、评测、压缩及部署实践
|
13天前
|
编解码 机器人 测试技术
技术实践 | 使用 PAI+LLaMA Factory 微调 Qwen2-VL 模型快速搭建专业领域知识问答机器人
Qwen2-VL是一款具备高级图像和视频理解能力的多模态模型,支持多种语言,适用于多模态应用开发。通过PAI和LLaMA Factory框架,用户可以轻松微调Qwen2-VL模型,快速构建文旅领域的知识问答机器人。本教程详细介绍了从模型部署、微调到对话测试的全过程,帮助开发者高效实现定制化多模态应用。
|
1月前
|
机器学习/深度学习 PyTorch API
优化注意力层提升 Transformer 模型效率:通过改进注意力机制降低机器学习成本
Transformer架构自2017年被Vaswani等人提出以来,凭借其核心的注意力机制,已成为AI领域的重大突破。该机制允许模型根据任务需求灵活聚焦于输入的不同部分,极大地增强了对复杂语言和结构的理解能力。起初主要应用于自然语言处理,Transformer迅速扩展至语音识别、计算机视觉等多领域,展现出强大的跨学科应用潜力。然而,随着模型规模的增长,注意力层的高计算复杂度成为发展瓶颈。为此,本文探讨了在PyTorch生态系统中优化注意力层的各种技术,
64 6
优化注意力层提升 Transformer 模型效率:通过改进注意力机制降低机器学习成本
|
22天前
|
机器学习/深度学习 人工智能 算法
人工智能浪潮下的编程实践:构建你的第一个机器学习模型
在人工智能的巨浪中,每个人都有机会成为弄潮儿。本文将带你一探究竟,从零基础开始,用最易懂的语言和步骤,教你如何构建属于自己的第一个机器学习模型。不需要复杂的数学公式,也不必担心编程难题,只需跟随我们的步伐,一起探索这个充满魔力的AI世界。
39 12
|
29天前
|
机器学习/深度学习 Python
机器学习中评估模型性能的重要工具——混淆矩阵和ROC曲线。混淆矩阵通过真正例、假正例等指标展示模型预测情况
本文介绍了机器学习中评估模型性能的重要工具——混淆矩阵和ROC曲线。混淆矩阵通过真正例、假正例等指标展示模型预测情况,而ROC曲线则通过假正率和真正率评估二分类模型性能。文章还提供了Python中的具体实现示例,展示了如何计算和使用这两种工具来评估模型。
51 8
|
29天前
|
机器学习/深度学习 Python
机器学习中模型选择和优化的关键技术——交叉验证与网格搜索
本文深入探讨了机器学习中模型选择和优化的关键技术——交叉验证与网格搜索。介绍了K折交叉验证、留一交叉验证等方法,以及网格搜索的原理和步骤,展示了如何结合两者在Python中实现模型参数的优化,并强调了使用时需注意的计算成本、过拟合风险等问题。
50 6
|
1月前
|
机器学习/深度学习 数据采集 算法
从零到一:构建高效机器学习模型的旅程####
在探索技术深度与广度的征途中,我深刻体会到技术创新既在于理论的飞跃,更在于实践的积累。本文将通过一个具体案例,分享我在构建高效机器学习模型过程中的实战经验,包括数据预处理、特征工程、模型选择与优化等关键环节,旨在为读者提供一个从零开始构建并优化机器学习模型的实用指南。 ####
|
1月前
|
人工智能 边缘计算 JSON
DistilQwen2 蒸馏小模型在 PAI-QuickStart 的训练、评测、压缩及部署实践
本文详细介绍在 PAI 平台使用 DistilQwen2 蒸馏小模型的全链路最佳实践。
|
1月前
|
机器学习/深度学习 人工智能 算法
探索机器学习中的线性回归模型
本文深入探讨了机器学习中广泛使用的线性回归模型,从其基本概念和数学原理出发,逐步引导读者理解模型的构建、训练及评估过程。通过实例分析与代码演示,本文旨在为初学者提供一个清晰的学习路径,帮助他们在实践中更好地应用线性回归模型解决实际问题。

热门文章

最新文章

相关产品

  • 人工智能平台 PAI