机器学习技术:如何使用交叉验证和ROC曲线提高疾病预测的准确性和效率?

简介: ROC曲线则是一种可视化评估分类算法表现的图形呈现方法,用于绘制二分类模型的真阳性率和假阳性率之间的关系曲线。这种方法常用于比较不同分类器和优化分类器的性能。接下来的部分将详细介绍如何使用这两种方法,并提供实际案例和技术实践以及最佳实践建议。

一、引言



随着机器学习的普及,评估模型的性能越来越重要。交叉验证和ROC曲线是两种常见的评估模型性能的方法。本文将介绍这两种方法的基本原理和应用场景,并结合实际案例和技术实践,讲解如何使用交叉验证和ROC曲线来提高机器学习模型的性能。此外,文章也将提供一些最佳实践建议,以帮助读者在实际工作中应用这些方法。


交叉验证是一种常用于评估机器学习模型性能的方法。它将数据集分成k个子集,每次取其中的k-1个子集作为训练集,剩下的一个作为测试集,重复k次后得到k个模型和性能指标,并对这些结果进行平均。这种方法能够充分利用数据集,并且评估结果更加稳定可靠。


ROC曲线则是一种可视化评估分类算法表现的图形呈现方法,用于绘制二分类模型的真阳性率和假阳性率之间的关系曲线。这种方法常用于比较不同分类器和优化分类器的性能。


接下来的部分将详细介绍如何使用这两种方法,并提供实际案例和技术实践以及最佳实践建议。


二、交叉验证



交叉验证是一种常用于评估机器学习模型性能的方法,其基本原理如下:


  • 将数据集分成k个子集。
  • 使用其中k-1个子集作为训练集,剩下的一个子集作为验证集。
  • 用训练集训练模型,并使用验证集验证模型性能。
  • 重复上述步骤k次,每次使用不同的子集作为验证集。
  • 计算k次结果的平均值作为最终结果。


如何使用交叉验证评估模型的性能:


  • 训练模型
  • 使用交叉验证方法评估模型性能
  • 分析模型性能结果


交叉验证的常用算法包括:随机划分、分层采样、分组采样等,不同算法在不同数据集和问题上表现不同,需要选择合适的算法。


常见的交叉验证评估指标包括:准确率、F1值、召回率和精度等。在选择交叉验证比较算法时,需要根据不同的评估指标和问题类型选择最合适的算法。通过交叉验证,我们能够更加准确地评估机器学习模型的性能,从而选择最优的模型或算法。从k折到自助法:常用交叉验证方法的优缺点


三、ROC曲线



ROC曲线是一种可视化评估分类算法表现的图形呈现方法,用于绘制不同阈值下真阳性率和假阳性率之间的关系曲线。


3.1 ROC曲线的原理和定义


  • 真阳性率(True Positive Rate, TPR):真实类别为正例的样本中,被分类器成功分类为正例的比例。
  • 假阳性率(False Positive Rate, FPR):真实类别为负例的样本中,被错误分类为正例的比例。
  • ROC曲线:在真阳性率与假阳性率之间进行折中,故可以通过绘制真阳性率/假阳性率相对于不同阈值的关系曲线来评估分类器的性能。


3.2 如何使用ROC曲线评估模型的性能


绘制ROC曲线,根据曲线下的面积(AUC),来评估分类器的性能。


常见问题:

  • 分类不平衡导致的ROC曲线“偏移”:当负例远多于正例时,模型偏向于将样本全部预测为负例,此时TPR和FPR的比值较小,ROC曲线出现偏移或在图像左下角,此时AUC不能准确的评估模型性能,需要结合其他指标如准确率、召回率等进行综合分析。
  • 分类器性能对阈值的依赖:通常,分类器的性能与判定的阈值有直接关系,因此相同的分类器在不同阈值下,性能可能存在差异。在绘制ROC曲线的时候,需要绘制不同阈值下的曲线,以评估在不同阈值下模型的性能。


四、如何结合交叉验证和ROC曲线评估模型性能?



如何比较不同机器学习模型的性能?

  • 对于多个分类器,可以使用交叉验证和ROC曲线分别评估每个分类器的性能。
  • 比较不同分类器在不同阈值下的ROC曲线,选择AUC值最大的分类器。
  • 比较不同分类器在交叉验证后得到的平均性能指标,选择表现最好的分类器。


如何对机器学习模型进行参数调优和最优选择?

  • 使用网格搜索法或随机搜索法,寻找最优参数组合。
  • 比较不同参数组合下的分类器性能,选择表现最好的分类器。
  • 对比不同分类器的性能后,使用最优参数组合重新训练模型,得到分类器的最终版本。


接下来的部分将结合案例和实践,介绍如何使用交叉验证和ROC曲线综合评估分类器的性能,以及如何对机器学习模型进行参数调优和最优选择。


五、实际案例和技术实践



我们以鸢尾花数据集为例,介绍如何使用交叉验证和ROC曲线评估分类器的性能,并对分类器进行参数调优和最优选择。


5.1 gbsg数据集介绍


gbsg 是 survival 包中的一个数据集,它包含了与乳腺癌患者相关的一些特征和生存信息。该数据集由德国乳腺癌研究组(German Breast Cancer Study Group)收集。


install.packages("randomizr")
library(survival)
library(randomizr)
library(glmnet)
library(pROC)
head(gbsg)


结果展示:


> head(gbsg)
   pid age meno size grade nodes pgr er hormon rfstime status
1  132  49    0   18     2     2   0  0      0    1838      0
2 1575  55    1   20     3    16   0  0      0     403      1
3 1140  56    1   40     3     3   0  0      0    1603      0
4  769  45    0   25     3     1   0  4      0     177      0
5  130  65    1   30     2     5   0 36      1    1855      0
6 1642  48    0   52     2    11   0  0      0     842      1


5.2 实践步骤


  • 步骤1:加载数据集并拆分为训练集和测试集。
  • 步骤2:使用逻辑回归模型
  • 步骤3:使用决策树模型
  • 步骤4:在测试集上进行预测
  • 步骤5:计算auc
  • 步骤6:绘制ROC曲线
  • 步骤7:交叉验证


5.3 代码实现


  • 步骤1:加载数据集并拆分为训练集和测试集
set.seed(1234)
# 删除pid列
data <- subset(gbsg, select = -c(pid))
trainIndex <- sample(1:nrow(gbsg), 0.8 * nrow(gbsg))
train <- gbsg[trainIndex,]
test <- gbsg[-trainIndex,]
nrow(train)
nrow(test)


结果展示:

[1] 548
[1] 138


  • 步骤2:使用逻辑回归模型
# 创建并训练逻辑回归模型
lr_model <- glm(status ~ ., data = train, family = binomial(link = "logit"))
summary(lr_model)


结果展示:

Call:
glm(formula = status ~ ., family = binomial(link = "logit"), 
    data = train)
Deviance Residuals: 
    Min       1Q   Median       3Q      Max  
-2.4692  -0.9033  -0.4523   0.9058   2.2232  
Coefficients:
              Estimate Std. Error z value Pr(>|z|)    
(Intercept)  1.1771566  0.8549956   1.377   0.1686    
age         -0.0067695  0.0157300  -0.430   0.6669    
meno         0.4075039  0.3139902   1.298   0.1943    
size         0.0077945  0.0074443   1.047   0.2951    
grade        0.0271568  0.1772156   0.153   0.8782    
nodes        0.0482255  0.0209566   2.301   0.0214 *  
pgr         -0.0018286  0.0006756  -2.707   0.0068 ** 
er           0.0007048  0.0007740   0.911   0.3625    
hormon      -0.2889779  0.2220921  -1.301   0.1932    
rfstime     -0.0014294  0.0001740  -8.214   <2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
(Dispersion parameter for binomial family taken to be 1)
    Null deviance: 755.12  on 547  degrees of freedom
Residual deviance: 618.96  on 538  degrees of freedom
AIC: 638.96
Number of Fisher Scoring iterations: 4


  • 步骤3:使用决策树模型
library(rpart)
dt_model <- rpart(status ~ ., data = train, method = "class")
summary(dt_model)


结果展示:


Call:
rpart(formula = status ~ ., data = train, method = "class")
  n= 548 
          CP nsplit rel error    xerror       xstd
1 0.36546185      0 1.0000000 1.0000000 0.04681075
2 0.04417671      1 0.6345382 0.6947791 0.04369674
3 0.02208835      2 0.5903614 0.6385542 0.04266618
4 0.02008032      4 0.5461847 0.6586345 0.04305245
5 0.01070950      6 0.5060241 0.6746988 0.04334678
6 0.01004016     10 0.4618474 0.6867470 0.04355912
7 0.01000000     12 0.4417671 0.6867470 0.04355912
Variable importance
rfstime   nodes     pgr   grade     age      er    size 
     53      14       9       7       6       6       4 
Node number 1: 548 observations,    complexity param=0.3654618
  predicted class=0  expected loss=0.4543796  P(node) =1
    class counts:   299   249
   probabilities: 0.546 0.454 
  left son=2 (295 obs) right son=3 (253 obs)
  Primary splits:
      rfstime < 966.5  to the right, improve=47.781200, (0 missing)
      nodes   < 3.5    to the left,  improve=19.571190, (0 missing)
      pgr     < 103.5  to the right, improve=12.617530, (0 missing)
      er      < 14.5   to the right, improve= 4.926028, (0 missing)
      grade   < 1.5    to the left,  improve= 4.652411, (0 missing)
  Surrogate splits:
      nodes < 5.5    to the left,  agree=0.637, adj=0.213, (0 split)
      pgr   < 25.5   to the right, agree=0.626, adj=0.190, (0 split)
      er    < 8.5    to the right, agree=0.588, adj=0.107, (0 split)
      grade < 2.5    to the left,  agree=0.569, adj=0.067, (0 split)
      age   < 44.5   to the right, agree=0.555, adj=0.036, (0 split)
Node number 2: 295 observations,    complexity param=0.0107095
  predicted class=0  expected loss=0.2610169  P(node) =0.5383212
    class counts:   218    77
   probabilities: 0.739 0.261 
  left son=4 (155 obs) right son=5 (140 obs)
  Primary splits:
      rfstime < 1593.5 to the right, improve=7.364219, (0 missing)
      pgr     < 123    to the right, improve=2.775756, (0 missing)
      nodes   < 3.5    to the left,  improve=2.634844, (0 missing)
      age     < 68.5   to the left,  improve=2.093624, (0 missing)
      size    < 14     to the left,  improve=1.649638, (0 missing)
  Surrogate splits:
      size  < 24.5   to the left,  agree=0.593, adj=0.143, (0 split)
      pgr   < 26.5   to the right, agree=0.553, adj=0.057, (0 split)
      age   < 37.5   to the right, agree=0.549, adj=0.050, (0 split)
      grade < 2.5    to the left,  agree=0.549, adj=0.050, (0 split)
      nodes < 4.5    to the left,  agree=0.549, adj=0.050, (0 split)
...


  • 步骤4:在测试集上进行预测
lr_pred <- predict(lr_model, newdata = test, type = "response")
dt_pred <- predict(dt_model, newdata = test, type = "class")
dt_pred <- as.numeric(as.character(dt_pred))


  • 步骤5:计算auc
# 计算逻辑回归模型的ROC曲线
lr_roc <- roc(test$status, lr_pred)
lr_auc <- auc(lr_roc)
lr_auc
# 计算决策树模型的ROC曲线
dt_roc <- roc(test$status, dt_pred_class)
dt_auc <- auc(dt_roc)
dt_auc


结果展示:


# > lr_roc
Call:
roc.default(response = test$status, predictor = lr_pred)
Data: lr_pred in 88 controls (test$status 0) < 50 cases (test$status 1).
Area under the curve: 0.8352
# > dt_roc
Call:
roc.default(response = test$status, predictor = dt_pred_class)
Data: dt_pred_class in 88 controls (test$status 0) < 50 cases (test$status 1).
Area under the curve: 0.715


  • 步骤6:绘制ROC曲线
plot(lr_roc, col = "blue", main = "ROC Curve", xlab = "False Positive Rate", ylab = "True Positive Rate", print.thres = TRUE, print.auc = TRUE, legacy.axes = TRUE)
lines(dt_roc, col = "red")
legend("bottomright", legend = c(paste("Logistic Regression (AUC =", round(lr_auc, 2), ")"), paste("Decision Tree (AUC =", round(dt_auc, 2), ")")), col = c("blue", "red"), lty = 1)


640.png

从ROC曲线和auc值看,逻辑回归预测的准确率是要高于决策树的,相对更适合该数据集。


  • 步骤7:交叉验证
library(caret)
# 交叉验证
k = 5
control <- trainControl(method = "cv", number = k)
# 使用train函数逻辑回归模型进行交叉验证
result_lr <- train(status ~ ., data = train, method = "glm",trControl = control) 
# 查看逻辑回归交叉验证的结果
result_lr
# 使用train函数决策树模型进行交叉验证
result_repart <- train(status ~ ., data = train, method = "rpart",trControl = control) 
# 查看决策树交叉验证的结果
result_repart


结果展示:


# > result_lr
Generalized Linear Model 
548 samples
  9 predictor
No pre-processing
Resampling: Cross-Validated (5 fold) 
Summary of sample sizes: 439, 438, 439, 438, 438 
Resampling results:
  RMSE       Rsquared   MAE     
  0.4486714  0.2001741  0.396397
# > result_repart
CART 
548 samples
  9 predictor
No pre-processing
Resampling: Cross-Validated (5 fold) 
Summary of sample sizes: 438, 439, 439, 438, 438 
Resampling results across tuning parameters:
  cp          RMSE       Rsquared   MAE      
  0.03500388  0.4539174  0.1900666  0.3885073
  0.03912224  0.4539174  0.1900666  0.3885073
  0.17584786  0.4855547  0.1045128  0.4554335
RMSE was used to select the optimal model using the smallest value.
The final value used for the model was cp = 0.03912224.


目录
相关文章
|
8月前
|
机器学习/深度学习 数据可视化 计算机视觉
【视频】机器学习交叉验证CV原理及R语言主成分PCA回归分析犯罪率|数据共享
【视频】机器学习交叉验证CV原理及R语言主成分PCA回归分析犯罪率|数据共享
|
8月前
|
机器学习/深度学习 人工智能 自然语言处理
探索机器学习中的自然语言处理技术
【2月更文挑战第31天】 随着人工智能的飞速发展,自然语言处理(NLP)技术在机器学习领域扮演着越来越重要的角色。本文旨在深入探讨NLP的关键技术,包括语言模型、词嵌入和深度学习方法,并分析这些技术如何相互协作,以实现更高效的文本分析和理解。通过案例研究和最新研究成果的介绍,我们展示了NLP在实际应用中的强大潜力,以及它如何推动人机交互和信息检索系统的革新。
208 0
|
2月前
|
机器学习/深度学习 Python
机器学习中评估模型性能的重要工具——混淆矩阵和ROC曲线。混淆矩阵通过真正例、假正例等指标展示模型预测情况
本文介绍了机器学习中评估模型性能的重要工具——混淆矩阵和ROC曲线。混淆矩阵通过真正例、假正例等指标展示模型预测情况,而ROC曲线则通过假正率和真正率评估二分类模型性能。文章还提供了Python中的具体实现示例,展示了如何计算和使用这两种工具来评估模型。
69 8
|
2月前
|
机器学习/深度学习 数据采集 算法
机器学习在医疗诊断中的前沿应用,包括神经网络、决策树和支持向量机等方法,及其在医学影像、疾病预测和基因数据分析中的具体应用
医疗诊断是医学的核心,其准确性和效率至关重要。本文探讨了机器学习在医疗诊断中的前沿应用,包括神经网络、决策树和支持向量机等方法,及其在医学影像、疾病预测和基因数据分析中的具体应用。文章还讨论了Python在构建机器学习模型中的作用,面临的挑战及应对策略,并展望了未来的发展趋势。
154 1
|
3月前
|
机器学习/深度学习 计算机视觉 Python
模型预测笔记(三):通过交叉验证网格搜索机器学习的最优参数
本文介绍了网格搜索(Grid Search)在机器学习中用于优化模型超参数的方法,包括定义超参数范围、创建参数网格、选择评估指标、构建模型和交叉验证策略、执行网格搜索、选择最佳超参数组合,并使用这些参数重新训练模型。文中还讨论了GridSearchCV的参数和不同机器学习问题适用的评分指标。最后提供了使用决策树分类器进行网格搜索的Python代码示例。
180 1
|
8月前
|
机器学习/深度学习 人工智能 运维
【人工智能技术专题】「入门到精通系列教程」打好AI基础带你进军人工智能领域的全流程技术体系(机器学习知识导论)(二)
【人工智能技术专题】「入门到精通系列教程」打好AI基础带你进军人工智能领域的全流程技术体系(机器学习知识导论)
317 1
|
8月前
|
机器学习/深度学习 人工智能 自然语言处理
【人工智能技术专题】「入门到精通系列教程」打好AI基础带你进军人工智能领域的全流程技术体系(机器学习知识导论)(一)
【人工智能技术专题】「入门到精通系列教程」打好AI基础带你进军人工智能领域的全流程技术体系(机器学习知识导论)
393 1
|
6月前
|
机器学习/深度学习 人工智能
8个特征工程技巧提升机器学习预测准确性
8个特征工程技巧提升机器学习预测准确性
122 6
8个特征工程技巧提升机器学习预测准确性
|
5月前
|
机器学习/深度学习 开发者 Python
Python 与 R 在机器学习入门中的学习曲线差异
【8月更文第6天】在机器学习领域,Python 和 R 是两种非常流行的编程语言。Python 以其简洁的语法和广泛的社区支持著称,而 R 则以其强大的统计功能和数据分析能力受到青睐。本文将探讨这两种语言在机器学习入门阶段的学习曲线差异,并通过构建一个简单的线性回归模型来比较它们的体验。
76 7
|
5月前
|
机器学习/深度学习 索引