不平衡数据集的建模的技巧和策略

简介: 不平衡数据集是指一个类中的示例数量与另一类中的示例数量显著不同的情况。 例如在一个二元分类问题中,一个类只占总样本的一小部分,这被称为不平衡数据集。类不平衡会在构建机器学习模型时导致很多问题。

不平衡数据集的主要问题之一是模型可能会偏向多数类,从而导致预测少数类的性能不佳。 这是因为模型经过训练以最小化错误率,并且当多数类被过度代表时,模型倾向于更频繁地预测多数类。 这会导致更高的准确率得分,但少数类别得分较低。

另一个问题是,当模型暴露于新的、看不见的数据时,它可能无法很好地泛化。 这是因为该模型是在倾斜的数据集上训练的,可能无法处理测试数据中的不平衡。

在本文中,我们将讨论处理不平衡数据集和提高机器学习模型性能的各种技巧和策略。 将涵盖的一些技术包括重采样技术、代价敏感学习、使用适当的性能指标、集成方法和其他策略。 通过这些技巧,可以为不平衡的数据集构建有效的模型。

处理不平衡数据集的技巧

重采样技术是处理不平衡数据集的最流行方法之一。 这些技术涉及减少多数类中的示例数量或增加少数类中的示例数量。

欠采样可以从多数类中随机删除示例以减小其大小并平衡数据集。 这种技术简单易行,但会导致信息丢失,因为它会丢弃一些多数类示例。

过采样与欠采样相反,过采样随机复制少数类中的示例以增加其大小。 这种技术可能会导致过度拟合,因为模型是在少数类的重复示例上训练的。

SMOTE是一种更高级的技术,它创建少数类的合成示例,而不是复制现有示例。 这种技术有助于在不引入重复项的情况下平衡数据集。

代价敏感学习(Cost-sensitive learning)是另一种可用于处理不平衡数据集的技术。 在这种方法中,不同的错误分类成本被分配给不同的类别。 这意味着与错误分类多数类示例相比,模型因错误分类少数类示例而受到更严重的惩罚。

在处理不平衡的数据集时,使用适当的性能指标也很重要。 准确性并不总是最好的指标,因为在处理不平衡的数据集时它可能会产生误导。 相反,使用 AUC-ROC等指标可以更好地指示模型性能。

集成方法,例如 bagging 和 boosting,也可以有效地对不平衡数据集进行建模。 这些方法结合了多个模型的预测以提高整体性能。 Bagging 涉及独立训练多个模型并对它们的预测进行平均,而 boosting 涉及按顺序训练多个模型,其中每个模型都试图纠正前一个模型的错误。

重采样技术、成本敏感学习、使用适当的性能指标和集成方法是一些技巧和策略,可以帮助处理不平衡的数据集并提高机器学习模型的性能。

在不平衡数据集上提高模型性能的策略

收集更多数据是在不平衡数据集上提高模型性能的最直接策略之一。 通过增加少数类中的示例数量,模型将有更多信息可供学习,并且不太可能偏向多数类。 当少数类中的示例数量非常少时,此策略特别有用。

生成合成样本是另一种可用于提高模型性能的策略。 合成样本是人工创建的样本,与少数类中的真实样本相似。 这些样本可以使用 SMOTE等技术生成,该技术通过在现有示例之间进行插值来创建合成示例。 生成合成样本有助于平衡数据集并为模型提供更多示例以供学习。

使用领域知识来关注重要样本也是一种可行的策略,通过识别数据集中信息量最大的示例来提高模型性能。 例如,如果我们正在处理医学数据集,可能知道某些症状或实验室结果更能表明某种疾病。 通过关注这些例子可以提高模型准确预测少数类的能力。

最后可以使用异常检测等高级技术来识别和关注少数类示例。 这些技术可用于识别与多数类不同且可能是少数类示例的示例。 这可以通过识别数据集中信息量最大的示例来帮助提高模型性能。

在收集更多数据、生成合成样本、使用领域知识专注于重要样本以及使用异常检测等先进技术是一些可用于提高模型在不平衡数据集上的性能的策略。 这些策略可以帮助平衡数据集,为模型提供更多示例以供学习,并识别数据集中信息量最大的示例。

不平衡数据集的练习

这里我们使用信用卡欺诈分类的数据集演示处理不平衡数据的方法

 importpandasaspd
 importnumpyasnp
 fromsklearn.preprocessingimportRobustScaler
 fromsklearn.linear_modelimportLogisticRegression
 fromsklearn.model_selectionimporttrain_test_split
 fromsklearn.metricsimportaccuracy_score
 fromsklearn.metricsimportconfusion_matrix, classification_report,f1_score,recall_score,roc_auc_score, roc_curve
 importmatplotlib.pyplotasplt
 importseabornassns
 frommatplotlibimportrc,rcParams
 importitertools
 
 importwarnings
 warnings.filterwarnings("ignore", category=DeprecationWarning) 
 warnings.filterwarnings("ignore", category=FutureWarning) 
 warnings.filterwarnings("ignore", category=UserWarning)

读取数据

 df=pd.read_csv("creditcard.csv")
 df.head()
 print("Number of observations : " ,len(df))
 print("Number of variables : ", len(df.columns))
 #Number of observations :  284807
 #Number of variables :  31

查看数据集信息

 df.info()
 <class'pandas.core.frame.DataFrame'>
 RangeIndex: 284807entries, 0to284806
 Datacolumns (total31columns):
 #   Column  Non-Null Count   Dtype  
 ---  ------  --------------   -----  
 0   Time    284807non-null  float64
 1   V1      284807non-null  float64
 2   V2      284807non-null  float64
 3   V3      284807non-null  float64
 4   V4      284807non-null  float64
 5   V5      284807non-null  float64
 6   V6      284807non-null  float64
 7   V7      284807non-null  float64
 8   V8      284807non-null  float64
 9   V9      284807non-null  float64
 10  V10     284807non-null  float64
 11  V11     284807non-null  float64
 12  V12     284807non-null  float64
 13  V13     284807non-null  float64
 14  V14     284807non-null  float64
 15  V15     284807non-null  float64
 16  V16     284807non-null  float64
 17  V17     284807non-null  float64
 18  V18     284807non-null  float64
 19  V19     284807non-null  float64
 20  V20     284807non-null  float64
 21  V21     284807non-null  float64
 22  V22     284807non-null  float64
 23  V23     284807non-null  float64
 24  V24     284807non-null  float64
 25  V25     284807non-null  float64
 26  V26     284807non-null  float64
 27  V27     284807non-null  float64
 28  V28     284807non-null  float64
 29  Amount  284807non-null  float64
 30  Class   284807non-null  int64  
 dtypes: float64(30), int64(1)
 memoryusage: 67.4MB

查看分类类别:

 f,ax=plt.subplots(1,2,figsize=(18,8))
 df['Class'].value_counts().plot.pie(explode=[0,0.1],autopct='%1.1f%%',ax=ax[0],shadow=True)
 ax[0].set_title('dağılım')
 ax[0].set_ylabel('')
 sns.countplot('Class',data=df,ax=ax[1])
 ax[1].set_title('Class')
 plt.show()

 rob_scaler=RobustScaler()
 df['Amount'] =rob_scaler.fit_transform(df['Amount'].values.reshape(-1,1))
 df['Time'] =rob_scaler.fit_transform(df['Time'].values.reshape(-1,1))
 df.head()

创建基类模型

 X=df.drop("Class", axis=1)
 y=df["Class"]
 X_train, X_test, y_train, y_test=train_test_split(X, y, test_size=0.20, random_state=123456)
 model=LogisticRegression(random_state=123456)
 model.fit(X_train, y_train)
 y_pred=model.predict(X_test)
 accuracy=accuracy_score(y_test, y_pred)
 print("Accuracy: %.3f"%(accuracy))

我们创建的模型的准确率评分为0.999。我们可以说我们的模型很完美吗?

混淆矩阵是一个用来描述分类模型的真实值在测试数据上的性能的表。它包含4种不同的估计值和实际值的组合。

 defplot_confusion_matrix(cm, classes,
                           title='Confusion matrix',
                           cmap=plt.cm.Blues):
 
     plt.rcParams.update({'font.size': 19})
     plt.imshow(cm, interpolation='nearest', cmap=cmap)
     plt.title(title,fontdict={'size':'16'})
     plt.colorbar()
     tick_marks=np.arange(len(classes))
     plt.xticks(tick_marks, classes, rotation=45,fontsize=12,color="blue")
     plt.yticks(tick_marks, classes,fontsize=12,color="blue")
     rc('font', weight='bold')
     fmt='.1f'
     thresh=cm.max()
     fori, jinitertools.product(range(cm.shape[0]), range(cm.shape[1])):
         plt.text(j, i, format(cm[i, j], fmt),
                  horizontalalignment="center",
                  color="red")
 
     plt.ylabel('True label',fontdict={'size':'16'})
     plt.xlabel('Predicted label',fontdict={'size':'16'})
     plt.tight_layout()
 
 plot_confusion_matrix(confusion_matrix(y_test, y_pred=y_pred), classes=['Non Fraud','Fraud'],
                       title='Confusion matrix')

•非欺诈类共进行了56875次预测,其中56870次(TP)正确,5次(FP)错误。

•欺诈类共进行了87次预测,其中31次(FN)错误,56次(TN)正确。

该模型可以预测欺诈状态,准确率为0.99。但当检查混淆矩阵时,欺诈类的错误预测率相当高。也就是说该模型正确地预测了非欺诈类的概率为0.99。但是非欺诈类的观测值的数量高于欺诈类的观测值的数量,这拉搞了我们对准确率的计算,并且我们更加关注的是欺诈类的准确率,所以我们需要一个指标来衡量它的性能。

选择正确的指标

在处理不平衡数据集时,选择正确的指标来评估模型的性能非常重要。 传统指标,如准确性、精确度和召回率,可能不适用于不平衡的数据集,因为它们没有考虑数据中类别的分布。

经常用于不平衡数据集的一个指标是 F1 分数。 F1 分数是精确率和召回率的调和平均值,它提供了两个指标之间的平衡。 计算如下:

F1 = 2 (precision recall) / (precision + recall)

另一个经常用于不平衡数据集的指标是 AUC-ROC。 AUC-ROC 衡量模型区分正类和负类的能力。 它是通过绘制不同分类阈值下的TPR与FPR来计算的。 AUC-ROC 值的范围从 0.5(随机猜测)到 1.0(完美分类)。

 print(classification_report(y_test, y_pred))
 
                  precision   recall   f1-score   support
 
            0       1.00      1.00      1.00     56875
            1       0.92      0.64      0.76        87
 
     accuracy                           1.00     56962
    macroavg       0.96      0.82      0.88     56962
 weightedavg       1.00      1.00      1.00     56962

返回对0(非欺诈)类的预测有多少是正确的。查看混淆矩阵,56870 + 31 = 56901个非欺诈类预测,其中56870个预测正确。0类的精度值接近1 (56870 / 56901)

返回对1 (欺诈)类的预测有多少是正确的。查看混淆矩阵,5 + 56 = 61个欺诈类别预测,其中56个被正确估计。0类的精度为0.92 (56 / 61),可以看到差别还是很大的

过采样

通过复制少数类样本来稳定数据集。

随机过采样:通过添加从少数群体中随机选择的样本来平衡数据集。如果数据集很小,可以使用这种技术。可能会导致过拟合。randomoverampler方法接受sampling_strategy参数,当sampling_strategy = ' minority '被调用时,它会增加minority类的数量,使其与majority类的数量相等。

我们可以在这个参数中输入一个浮点值。例如,假设我们的少数群体人数为1000人,多数群体人数为100人。如果我们说sampling_strategy = 0.5,少数类将被添加到500。

 y_train.value_counts()
 0    227440
 1       405
 Name: Class, dtype: int64
 
 fromimblearn.over_samplingimportRandomOverSampler
 oversample=RandomOverSampler(sampling_strategy='minority')
 X_randomover, y_randomover=oversample.fit_resample(X_train, y_train)

采样后训练

 model.fit(X_randomover, y_randomover)
 y_pred=model.predict(X_test)
 
 plot_confusion_matrix(confusion_matrix(y_test, y_pred=y_pred), classes=['Non Fraud','Fraud'],
                       title='Confusion matrix')

应用随机过采样后,训练模型的精度值为0.97,出现了下降。但是从混淆矩阵来看,模型的欺诈类的正确估计率有所提高。

SMOTE 过采样:从少数群体中随机选取一个样本。然后,为这个样本找到k个最近的邻居。从k个最近的邻居中随机选取一个,将其与从少数类中随机选取的样本组合在特征空间中形成线段,形成合成样本。

 from imblearn.over_sampling import SMOTE
 oversample = SMOTE()
 X_smote, y_smote = oversample.fit_resample(X_train, y_train)

使用SMOTE后的数据训练

 model.fit(X_smote, y_smote)
 y_pred = model.predict(X_test)
 
 accuracy = accuracy_score(y_test, y_pred)
 plot_confusion_matrix(confusion_matrix(y_test, y_pred=y_pred), classes=['Non Fraud','Fraud'],
                       title='Confusion matrix')

可以看到与基线模型相比,欺诈的准确率有所提高,但是比随机过采样有所下降,这可能是数据集的原因,因为SMOTE采样会生成心的数据,所以并不适合所有的数据集。

总结

在这篇文章中,我们讨论了处理不平衡数据集和提高机器学习模型性能的各种技巧和策略。不平衡的数据集可能是机器学习中的一个常见问题,并可能导致在预测少数类时表现不佳。

本文介绍了一些可用于平衡数据集的重采样技术,如欠采样、过采样和SMOTE。还讨论了成本敏感学习和使用适当的性能指标,如AUC-ROC,这可以提供更好的模型性能指示。

处理不平衡的数据集是具有挑战性的,但通过遵循本文讨论的技巧和策略,可以建立有效的模型准确预测少数群体。重要的是要记住最佳方法将取决于特定的数据集和问题,为了获得最佳结果,可能需要结合各种技术。因此,试验不同的技术并使用适当的指标评估它们的性能是很重要的。

https://avoid.overfit.cn/post/774ca6891f26470093970c074afceede

作者:Emine Bozkuş

目录
相关文章
|
7月前
|
机器学习/深度学习 数据采集 监控
机器学习-特征选择:如何使用递归特征消除算法自动筛选出最优特征?
机器学习-特征选择:如何使用递归特征消除算法自动筛选出最优特征?
974 0
|
3月前
|
机器学习/深度学习 数据可视化 数据建模
使用ClassificationThresholdTuner进行二元和多类分类问题阈值调整,提高模型性能增强结果可解释性
在分类问题中,调整决策的概率阈值虽常被忽视,却是提升模型质量的有效步骤。本文深入探讨了阈值调整机制,尤其关注多类分类问题,并介绍了一个名为 ClassificationThresholdTuner 的开源工具,该工具自动化阈值调整和解释过程。通过可视化功能,数据科学家可以更好地理解最优阈值及其影响,尤其是在平衡假阳性和假阴性时。此外,工具支持多类分类,解决了传统方法中的不足。
52 2
使用ClassificationThresholdTuner进行二元和多类分类问题阈值调整,提高模型性能增强结果可解释性
|
2月前
|
机器学习/深度学习 算法 数据建模
【机器学习】类别不平衡数据的处理
【机器学习】类别不平衡数据的处理
|
4月前
|
机器学习/深度学习 算法
【机器学习】不同决策树的节点分裂准则(属性划分标准)
决策树的不同节点分裂准则,包括原始决策树的节点分裂准则、ID3算法的信息增益、C4.5算法的信息增益比以及CART算法的平方根误差最小化和基尼指数。
59 1
|
5月前
|
机器学习/深度学习 索引 Python
。这不仅可以减少过拟合的风险,还可以提高模型的准确性、降低计算成本,并帮助理解数据背后的真正含义。`sklearn.feature_selection`模块提供了多种特征选择方法,其中`SelectKBest`是一个元变换器,可以与任何评分函数一起使用来选择数据集中K个最好的特征。
。这不仅可以减少过拟合的风险,还可以提高模型的准确性、降低计算成本,并帮助理解数据背后的真正含义。`sklearn.feature_selection`模块提供了多种特征选择方法,其中`SelectKBest`是一个元变换器,可以与任何评分函数一起使用来选择数据集中K个最好的特征。
|
7月前
|
机器学习/深度学习 数据可视化
数据分享|R语言生存分析模型因果分析:非参数估计、IP加权风险模型、结构嵌套加速失效(AFT)模型分析流行病学随访研究数据
数据分享|R语言生存分析模型因果分析:非参数估计、IP加权风险模型、结构嵌套加速失效(AFT)模型分析流行病学随访研究数据
|
7月前
|
机器学习/深度学习 算法 数据挖掘
实战Scikit-Learn:处理不平衡数据集的策略
【4月更文挑战第17天】本文探讨了Scikit-Learn处理不平衡数据集的策略,包括重采样(过采样少数类如SMOTE,欠采样多数类如RandomUnderSampler)、修改损失函数(如加权损失函数)、使用集成学习(如随机森林、AdaBoost)以及选择合适的评估指标(精确率、召回率、F1分数)。这些方法有助于提升模型对少数类的预测性能和泛化能力。
|
7月前
线性回归前特征离散化可简化模型、增强稳定性、选有意义特征、降低过拟合、提升计算效率及捕捉非线性关系。
【5月更文挑战第2天】线性回归前特征离散化可简化模型、增强稳定性、选有意义特征、降低过拟合、提升计算效率及捕捉非线性关系。但过多离散特征可能增加复杂度,丢失信息,影响模型泛化和精度。需谨慎平衡离散化利弊。
56 0
|
7月前
|
机器学习/深度学习 存储 编解码
重参架构的量化问题解决了 | 粗+细粒度权重划分量化让RepVGG-A1仅损失0.3%准确性
重参架构的量化问题解决了 | 粗+细粒度权重划分量化让RepVGG-A1仅损失0.3%准确性
89 0
重参架构的量化问题解决了 | 粗+细粒度权重划分量化让RepVGG-A1仅损失0.3%准确性
|
机器学习/深度学习 监控 搜索推荐
深度粗排模型的GMV优化实践:基于全空间-子空间联合建模的蒸馏校准模型
随着业务的不断发展,粗排模型在整个系统链路中变得越来越重要,能够显著提升线上效果。本文是对粗排模型优化的阶段性总结。
1606 0
深度粗排模型的GMV优化实践:基于全空间-子空间联合建模的蒸馏校准模型