五种交叉验证方法的 R 语言实现

简介: 五种交叉验证方法的 R 语言实现

引言

交叉验证(Cross-Validation)是一种常用的模型评估方法,旨在评估机器学习模型在未知数据上的性能。它通过将数据集划分为多个互斥的子集,然后进行多次训练和测试,从而更全面地评估模型的泛化能力。

交叉验证方法

在机器学习中,常用的交叉验证方法有以下几种:

  • K 折交叉验证(K-Fold Cross Validation)
  • 留一交叉验证(Leave-One-Out Cross Validation,简称LOOCV)
  • 留 P 交叉验证(Leave-P-Out Cross Validation,简称LPOCV)
  • 自助法交叉验证(Bootstrap Cross Validation)
  • 分层 K 折交叉验证(Stratified K-Fold Cross Validation)

下面以线性回归作为目标模型,分别介绍这五种方法的原理及 R 语言实现。

K 折交叉验证

将原始数据随机分成 K 个子集(通常是等分),每次选择其中一个子集作为测试集,其他 K-1 个子集作为训练集,进行 K 次训练和验证,最后将 K 次验证结果平均得到最终的性能评估。K 折交叉验证是最常用的交叉验证方法。

K 折交叉验证的 R 语言实现需要使用caret包中的 trainControl()函数,设置参数method = "cv", number = 5表示五折交叉验证:

# 导入所需包
library(ggplot2)
library(lattice)
library(caret)
# 以鸢尾花数据集为例
data(iris)  
# 创建模型
set.seed(123)  
model <- train(Sepal.Length~.,data=iris, method = "lm", 
               trControl = trainControl(method = "cv", number = 5))
# 查看交叉验证结果
print(model)

留一交叉验证

将原始数据集中的一个样本作为测试集,其他样本作为训练集,进行 N次训练和验证(N 为样本数量),每个训练集都只包含一个样本。LOOCV 在样本数量较小的情况下比较常用,但计算成本较高。

LOOCV 的 R 语言实现只需要在上述代码的基础上,将 trainControl()函数改为 method = "LOOCV"

# 以鸢尾花数据集为例
data(iris)  
set.seed(123)  
# 创建模型
model <- train(Sepal.Length~.,data=iris, method = "lm", 
               trControl = trainControl(method = "LOOCV"))
# 查看交叉验证结果
print(model)

留 P 交叉验证

与留一交叉验证类似,但每次留 P 个样本作为测试集,其他样本作为训练集,进行多次训练和验证。LPOCV 是 LOOCV 的一种扩展形式,可以减小计算开销。

LPOCV 的 R 语言实现需要将 trainControl()函数改为 method = "LGOCV", p = 0.7,其中 p=0.7 表示 70% 的数据作为训练集 (注意:函数中的 p 和 LPOCV 中的 p含义不同,并且在 R 中用LGOCV来表示留 P 交叉验证):

# 以鸢尾花数据集为例
data(iris)  
set.seed(123)  
# 创建模型
model <- train(Sepal.Length~.,data=iris, method = "lm", 
               trControl = trainControl(method = "LGOCV", p = 0.7))
# 查看交叉验证结果
print(model)

自助法交叉验证

从原始数据集中有放回采样得到一个新的训练集,部分样本可能被重复采样,未被抽到的样本作为测试集。采样和验证的过程多次进行,最后将多次验证结果的平均值作为性能评估。自助法交叉验证适用于数据集较小,且难以有效划分训练集和测试集的情况。

自助法交叉验证的 R 语言实现需要先创建样本索引 createDataPartition(iris$Sepal.Length, times = 5, p = 0.7)。其中,times表示抽样的重复次数,p表示训练集的比例:

# 以鸢尾花数据集为例
data(iris)  
set.seed(123)  
# 创建自助样本索引
# 创建自助样本索引
folds <- createDataPartition(iris$Sepal.Length, times = 5, p = 0.7)  
# 创建模型
model <- train(Sepal.Length~.,data=iris, method = "lm", 
               trControl = trainControl(method = "boot", index = folds))
# 查看交叉验证结果
print(model)

分层 K 折交叉验证

在进行 K 折交叉验证时,每个子集的样本类别比例都与原始数据集中保持一致。这种交叉验证方法常用于处理不平衡数据集的情况。

分层 K 折交叉验证需要先根据分类变量的比例来划分数据子集,然后再进行交叉验证:

library(caret)
##根据分类变量划分子集
folds <- createFolds(iris$Species, k = 5)
##检验每个子集中的类别百分比是否等于原始样本
prop.table(table(iris$Species[folds$Fold1]))
prop.table(table(iris$Species[folds$Fold2]))
prop.table(table(iris$Species[folds$Fold3]))
prop.table(table(iris$Species[folds$Fold4]))
prop.table(table(iris$Species[folds$Fold5]))
prop.table(table(iris$Species))
cv_r <- list()
RMSE <- list()
MAE <- list()
for (i in 1:length(folds)) {
  ##将一个折作为测试集
  test <- iris[folds[[i]],]
  ##其余折作为训练集
  train <- iris[-folds[[i]],]
  ##构建模型
  model <- lm(Sepal.Length~.,data=train)
  pred <- predict(model,newdata = data.frame(test))
  ##选取一些评价指标
  cv_r[[i]] <- summary(model)$r.squared
  RMSE[[i]] <- sqrt(mean((pred-test$Sepal.Length)^ 2))
  MAE[[i]] <- mean(abs(pred-test$Sepal.Length))
}
##每次迭代的交叉验证结果
cv_r_values <- unlist(cv_r)
cv_RMSE_values <- unlist(RMSE)
cv_MAE_values <- unlist(MAE)
##最终的结果
cv_mean_r <- mean(cv_r_values)
cv_mean_RMSE <- mean(cv_RMSE_values)
cv_mean_MAE <- mean(cv_MAE_values)

小结

交叉验证方法的选择取决于数据集的规模、性质以及具体的应用场景。注意,上述方法不适用于时间序列数据,因为时间序列具有时间依赖关系,不能随机选取训练集和测试集。 下一期再为大家分享时间序列交叉验证的方法。


目录
相关文章
|
6月前
|
移动开发 算法 数据可视化
【视频】马尔可夫链蒙特卡罗方法MCMC原理与R语言实现|数据分享(上)
【视频】马尔可夫链蒙特卡罗方法MCMC原理与R语言实现|数据分享
|
6月前
|
数据可视化 算法
【R语言实战】——kNN和朴素贝叶斯方法实战
【R语言实战】——kNN和朴素贝叶斯方法实战
|
6月前
|
数据可视化
R语言机器学习方法分析二手车价格影响因素
R语言机器学习方法分析二手车价格影响因素
|
6月前
|
移动开发 算法 数据可视化
【视频】马尔可夫链蒙特卡罗方法MCMC原理与R语言实现|数据分享(上)
【视频】马尔可夫链蒙特卡罗方法MCMC原理与R语言实现|数据分享
108 12
|
6月前
|
数据可视化 Python
R语言蒙特卡罗Monte Carlo方法进行数值积分和模拟可视化
R语言蒙特卡罗Monte Carlo方法进行数值积分和模拟可视化
|
6月前
|
机器学习/深度学习 算法 数据库
数据分享|R语言用核Fisher判别方法、支持向量机、决策树与随机森林研究客户流失情况
数据分享|R语言用核Fisher判别方法、支持向量机、决策树与随机森林研究客户流失情况
|
6月前
|
算法
【视频】马尔可夫链蒙特卡罗方法MCMC原理与R语言实现|数据分享(下)
【视频】马尔可夫链蒙特卡罗方法MCMC原理与R语言实现|数据分享
|
6月前
|
算法 数据可视化 Windows
R语言BUGS/JAGS贝叶斯分析: 马尔科夫链蒙特卡洛方法(MCMC)采样(2)
R语言BUGS/JAGS贝叶斯分析: 马尔科夫链蒙特卡洛方法(MCMC)采样(2)
|
6月前
|
算法 数据挖掘
R语言中的贝叶斯统计方法
【4月更文挑战第26天】R语言在贝叶斯统计中发挥着重要作用,提供如&quot;BUGS&quot;、&quot;Stan&quot;、&quot;JAGS&quot;等包来处理复杂模型和数值计算。贝叶斯方法基于概率论,涉及先验分布、似然函数、后验分布和MCMC模拟。&quot;BUGS&quot;适用于复杂层次模型,&quot;Stan&quot;则在大规模数据和复杂模型上有优势。
69 2
|
6月前
|
算法 数据可视化 Python
【视频】逆变换抽样将数据标准化和R语言结构化转换:BOX-COX、凸规则变换方法
【视频】逆变换抽样将数据标准化和R语言结构化转换:BOX-COX、凸规则变换方法