机器学习中不平衡数据集分类模型示例:乳腺钼靶微钙化摄影数据集(一)

简介: 机器学习中不平衡数据集分类模型示例:乳腺钼靶微钙化摄影数据集(一)

Deephub翻译组:Alexander Zhao

癌症检测是不平衡分类问题的一个普遍例子,因为非癌症病例往往比实际癌症病例多得多。

一个典型的不平衡分类数据集是乳腺摄影数据集,这个数据集用于从放射扫描中检测乳腺癌(特别是在乳腺摄影中出现明亮的微钙化簇)。研究人员通过扫描图像,对目标进行分割,然后用计算机视觉算法描述分割对象,从而获得了这一数据集。

由于类别不平衡十分严重,这是一个非常流行的不平衡分类数据集。其中98%的候选图像不是癌症,只有2%被有经验的放射科医生标记为癌症。

在本教程中,您将发现如何开发和评估乳腺癌钼靶摄影数据集的不平衡分类模型。完成本教程后,您将知道:

  • 如何加载和探索数据集,并从中获得预处理数据与选择模型的灵感。
  • 如何使用代价敏感算法评估一组机器学习模型并提高其性能。
  • 如何拟合最终模型并使用它预测特定情况下的类标签。

我们开始吧。

教程概述

本教程分为五个部分,分别是:

  1. 乳腺摄影数据集
  2. 浏览数据集
  3. 模型试验和基准结果
  4. 评估模型
  1. 评估机器学习算法
  2. 评估代价敏感算法
  1. 对新数据进行预测

乳腺摄影数据集

在这个项目中,我们将使用一个典型的不平衡机器学习数据集,即“乳腺摄影”数据集,有时称为“Woods乳腺摄影数据集”。

数据集归Kevin Woods等人所有,有关这项工作可以参考1993年发表的题为“Comparative Evaluation Of Pattern Recognition Techniques For Detection Of Microcalcifications In Mammography”的论文。

这个问题的焦点是通过放射扫描来检测乳腺癌,特别是在乳房X光片上出现的微小钙化团。

该数据集首先从24张已知癌症诊断结果的乳房X光片开始扫描,然后使用图像分割计算机视觉算法对图像进行预处理,从乳腺图像中提取候选目标。这些候选目标被分割后,就会被一位经验丰富的放射科医生手工标记。

Woods等人首先从分割的对象中选取与模式识别最相关的对象并进行特征提取,总共提取了29个特征,这些特征被减少到18个,然后最后减少到7个,具体如下(摘自论文原文):

  • 对象的面积(像素)
  • 对象的平均灰度
  • 对象周长像素的渐变强度
  • 对象中的均方根噪声波动
  • 对比度,也即对象的平均灰度减去对象周围两个像素宽边框的平均值
  • 基于形状描述子的低阶矩

这是一个二分类任务,目的是利用给定分割对象的特征来区分乳腺影片中的微钙化和非微钙化。

  • 非微钙化: 负类,或多数类
  • 微钙化: 正类,或少数类

作者评价与比较了一系列机器学习模型,包括人工神经网络、决策树、k近邻算法等。各个模型用受试者操作特性曲线(ROC)进行评估,并且用曲线下面积(AUC)进行比较。

选择ROC与AUC作为评估指标,其目的是最小化假阳性率(FPR,即特异性Specificity的补数)并最大化真阳性率(TPR,即敏感性Sensitivity),这也即ROC曲线的两轴。选用ROC曲线的另一个原因是可以让研究人员确定一个概率阈值,从而对可以接受的最大FPR与TPR进行权衡(因为随着TPR的上升FPR也不可避免地上升,译者注)。

研究结果表明,线性分类器(原文中应该是一个高斯朴素贝叶斯分类器)的AUC为0.936(100次实验的平均值),表现最好。

接下来,让我们仔细看看数据。

探索数据集

乳腺摄影数据集是一个广泛使用的标准机器学习数据集,用于探索和演示许多专门为不平衡分类设计的技术。一个典型的例子是流行的SMOTE技术。

我们使用的数据集是其中的一个版本,它与原始文件中描述的数据集有一些不同。

首先,下载数据集并将其保存在当前的工作目录中,名为“mammography.csv”

  • 下载乳房摄影数据集(maxomography.csv)

检查文件的内容。文件的前几行应如下所示:


0.23001961,5.0725783,-0.27606055,0.83244412,-0.37786573,0.4803223,'-1'
0.15549112,-0.16939038,0.67065219,-0.85955255,-0.37786573,-0.94572324,'-1'
-0.78441482,-0.44365372,5.6747053,-0.85955255,-0.37786573,-0.94572324,'-1'
0.54608818,0.13141457,-0.45638679,-0.85955255,-0.37786573,-0.94572324,'-1'
-0.10298725,-0.3949941,-0.14081588,0.97970269,-0.37786573,1.0135658,'-1'

我们可以看到数据集有6个输入变量,而不是7个输入变量。有可能从这个版本的数据集中删除了论文中列出的第一个输入变量(用像素描述的对象面积)。

输入变量是数值类型,而目标变量是多数类置为“-1”、少数类置为“1”的字符串。这些值需要分别编码为0和1,以满足分类算法对二进制不平衡分类问题的期望。

可以使用read_csv()这一Pandas函数将数据集加载为DataFrame数据结构,注意指定header=None


...# define the dataset locationfilename = 'mammography.csv'# load the csv file as a data framedataframe = read_csv(filename, header=None)

载入完毕后,我们调用DataFrame的shape方法打印其行列数。


...# summarize the shape of the datasetprint(dataframe.shape)

我们还可以通过使用Counter来确认数据,获取各类的比例。


...# summarize the class distributiontarget = dataframe.values[:,-1]counter = Counter(target)for k,v in counter.items():  per = v / len(target) * 100 print('Class=%s, Count=%d, Percentage=%.3f%%' % (k, v, per))

把这两步放在一起,完整的载入与确认数据的代码如下:


# load and summarize the datasetfrom pandas import read_csvfrom collections import Counter# define the dataset locationfilename = 'mammography.csv'# load the csv file as a data framedataframe = read_csv(filename, header=None)# summarize the shape of the datasetprint(dataframe.shape)# summarize the class distributiontarget = dataframe.values[:,-1]counter = Counter(target)for k,v in counter.items():  per = v / len(target) * 100 print('Class=%s, Count=%d, Percentage=%.3f%%' % (k, v, per))

运行示例代码,首先加载数据集并确认行和列的数量,即11183行、6个输入变量和1个目标变量。

然后确认类别分布,我们会观察到严重的类别不平衡:多数类(无癌症)约占98%,少数类(癌症)约占2%。


(11183, 7)Class='-1', Count=10923, Percentage=97.675%Class='1', Count=260, Percentage=2.325%

从反例和正例的比例来看,数据集似乎与SMOTE论文中描述的数据集基本匹配。

A typical mammography dataset might contain 98% normal pixels and 2% abnormal pixels.

— SMOTE: Synthetic Minority Over-sampling Technique, 2002.

此外,正反例的数量也基本与论文相符。

The experiments were conducted on the mammography dataset. There were 10923 examples in the majority class and 260 examples in the minority class originally.

— SMOTE: Synthetic Minority Over-sampling Technique, 2002.

我相信这是同一个数据集,尽管我无法解释输入特征数量的不匹配现象,例如我们的数据集中只有6个输入数据,而原始论文中有7个。

我们还可以为每个变量创建直方图来观察输入变量的分布,下面列出了完整的示例。


# create histograms of numeric input variablesfrom pandas import read_csvfrom matplotlib import pyplot# define the dataset locationfilename = 'mammography.csv'# load the csv file as a data framedf = read_csv(filename, header=None)# histograms of all variablesdf.hist()pyplot.show()

运行该示例代码将为数据集中的六个输入变量分别创建一个直方图。我们可以看到,这些变量有不同的取值范围,而且大多数变量都是指数分布的,例如,大多数情况下变量只占据直方图的一列,而其他情况下则留下一个长尾,而最后一个变量则似乎具有双峰分布。

根据我们选择的算法,将数据分布缩放到相同的取值范围是可能是很有用的,也许还需要使用一些幂变换,这将在后文进行讨论。

image.png

数据分布

我们还可以为每对输入变量创建一个散点图,称为散点图矩阵。这有助于我们了解是否有任何变量是相互关联的,或在同一方向上发生变化。我们还可以根据类标签给每个散点图上色。我们将多数类(没有癌症)标记为蓝点,少数类(癌症)标记为红点。

下面列出了完整的示例。


# create pairwise scatter plots of numeric input variablesfrom pandas import read_csvfrom pandas.plotting import scatter_matrixfrom matplotlib import pyplot# define the dataset locationfilename = 'mammography.csv'# load the csv file as a data framedf = read_csv(filename, header=None)# define a mapping of class values to colorscolor_dict = {"'-1'":'blue', "'1'":'red'}# map each row to a color based on the class valuecolors = [color_dict[str(x)] for x in df.values[:, -1]]# pairwise scatter plots of all numerical variablesscatter_matrix(df, diagonal='kde', color=colors)pyplot.show()

运行该示例将创建一个6组*6组的散点图矩阵,用于六个输入变量的相互比较。矩阵的对角线表示每个变量的密度分布。

每一对变量的分布比较都出现了两次,分别位于主对角线元素的左侧与上侧(或右侧与下侧)。这提供了两种查看变量分布的尺度。

我们可以看到,对于正类与负类,许多变量的分布确实不同,这表明在癌症病例和非癌症病例之间进行合理的区分是可行的。

image.png

散点图矩阵

现在我们已经回顾了数据集,接下来让我们来评估与测试备选模型。

目录
相关文章
|
4月前
|
人工智能 自然语言处理 IDE
模型微调不再被代码难住!PAI和Qwen3-Coder加速AI开发新体验
通义千问 AI 编程大模型 Qwen3-Coder 正式开源,阿里云人工智能平台 PAI 支持云上一键部署 Qwen3-Coder 模型,并可在交互式建模环境中使用 Qwen3-Coder 模型。
932 109
|
5月前
|
人工智能 自然语言处理 运维
【新模型速递】PAI-Model Gallery云上一键部署Kimi K2模型
月之暗面发布开源模型Kimi K2,采用MoE架构,参数达1T,激活参数32B,具备强代码能力及Agent任务处理优势。在编程、工具调用、数学推理测试中表现优异。阿里云PAI-Model Gallery已支持云端部署,提供企业级方案。
363 0
【新模型速递】PAI-Model Gallery云上一键部署Kimi K2模型
|
8月前
|
人工智能 JSON 算法
【解决方案】DistilQwen2.5-DS3-0324蒸馏小模型在PAI-ModelGallery的训练、评测、压缩及部署实践
DistilQwen 系列是阿里云人工智能平台 PAI 推出的蒸馏语言模型系列,包括 DistilQwen2、DistilQwen2.5、DistilQwen2.5-R1 等。本文详细介绍DistilQwen2.5-DS3-0324蒸馏小模型在PAI-ModelGallery的训练、评测、压缩及部署实践。
|
6月前
|
机器学习/深度学习 算法 安全
差分隐私机器学习:通过添加噪声让模型更安全,也更智能
本文探讨在敏感数据上应用差分隐私(DP)进行机器学习的挑战与实践。通过模拟DP-SGD算法,在模型训练中注入噪声以保护个人隐私。实验表明,该方法在保持71%准确率和0.79 AUC的同时,具备良好泛化能力,但也带来少数类预测精度下降的问题。研究强调差分隐私应作为模型设计的核心考量,而非事后补救,并提出在参数调优、扰动策略选择和隐私预算管理等方面的优化路径。
483 3
差分隐私机器学习:通过添加噪声让模型更安全,也更智能
|
5月前
|
人工智能 自然语言处理 运维
【新模型速递】PAI-Model Gallery云上一键部署gpt-oss系列模型
阿里云 PAI-Model Gallery 已同步接入 gpt-oss 系列模型,提供企业级部署方案。
|
6月前
|
机器学习/深度学习 人工智能 算法
Post-Training on PAI (4):模型微调SFT、DPO、GRPO
阿里云人工智能平台 PAI 提供了完整的模型微调产品能力,支持 监督微调(SFT)、偏好对齐(DPO)、强化学习微调(GRPO) 等业界常用模型微调训练方式。根据客户需求及代码能力层级,分别提供了 PAI-Model Gallery 一键微调、PAI-DSW Notebook 编程微调、PAI-DLC 容器化任务微调的全套产品功能。
|
7月前
|
存储 人工智能 运维
企业级MLOps落地:基于PAI-Studio构建自动化模型迭代流水线
本文深入解析MLOps落地的核心挑战与解决方案,涵盖技术断层分析、PAI-Studio平台选型、自动化流水线设计及实战构建,全面提升模型迭代效率与稳定性。
308 6
|
6月前
|
机器学习/深度学习 分布式计算 Java
Java 大视界 -- Java 大数据机器学习模型在遥感图像土地利用分类中的优化与应用(199)
本文探讨了Java大数据与机器学习模型在遥感图像土地利用分类中的优化与应用。面对传统方法效率低、精度差的问题,结合Hadoop、Spark与深度学习框架,实现了高效、精准的分类。通过实际案例展示了Java在数据处理、模型融合与参数调优中的强大能力,推动遥感图像分类迈向新高度。
|
6月前
|
机器学习/深度学习 存储 Java
Java 大视界 -- Java 大数据机器学习模型在游戏用户行为分析与游戏平衡优化中的应用(190)
本文探讨了Java大数据与机器学习模型在游戏用户行为分析及游戏平衡优化中的应用。通过数据采集、预处理与聚类分析,开发者可深入洞察玩家行为特征,构建个性化运营策略。同时,利用回归模型优化游戏数值与付费机制,提升游戏公平性与用户体验。

相关产品

  • 人工智能平台 PAI