基于 mlr 包的 K 最近邻算法介绍与实践(下)

简介: 在上期 KNN 算法介绍 的最后,我们指出:使用最初用来训练模型的数据进行预测的方式来评估模型性能是不合理的。本期将以上期的内容和数据为基础,介绍交叉验证的方法来评估模型性能、如何选择参数 k 来优化模型等内容。

前言


在上期 KNN 算法介绍 的最后,我们指出:使用最初用来训练模型的数据进行预测的方式来评估模型性能是不合理的。本期将以上期的内容和数据为基础,介绍交叉验证的方法来评估模型性能、如何选择参数 k 来优化模型等内容。


1. 交叉验证


通常情况下,我们会将已有的数据分为两部分:训练集 (training set) 和测试集 (test set)。使用训练集来训练模型,并用测试集的数据来评估模型性能。这个过程叫做交叉验证 (cross-validation)。常见的交叉验证方法有以下三种:

  • Hold-out cross-validation.
  • k-fold cross-validation.
  • leave-one-out cross-validation.

接下来,本文将从上期创建的任务和 learner 出发来分别介绍以上三种交叉验证方法。

diabetesTask <- makeClassifTask(data = diabetesTib, target = "class")
knn <- makeLearner("classif.knn", par.vals = list("k" = 2))


1.1 Hold-out cross-validation

Hold-out cross-validation 是最容易理解的方法:我们只需随机 "保留" 一个一定比例的数据作为测试集,并在剩余数据上训练模型,然后使用测试集来评估模型性能。

image.gifFig 1. Hold-out cross-validation 过程

在采用这种方法时,需要决定将多大比例的数据用作测试集。如果测试集太小,那么对性能的估计就会有很大的方差,但若训练集太小,那么对性能的估计就会有很大的偏差。通常,2/3的数据用于训练集,1/3用作测试集,但这也取决于数据中实例的数量。


1.1.1 Holdout 重采样描述

mlr 包中使用交叉验证,第一步是进行重采样描述,这是一组简单的指令,用于将数据分割成测试集和训练集。

makeResampleDesc() 函数的第一个参数是要使用的交叉验证方法,在本例中是 Holdout;第二个参数 split,用来设定多大比例的数据将被用作训练集;stratify = TRUE 确保在将数据拆分为训练集和测试集时,尽量保持糖尿病患者在每一类中的比例。

holdout <- makeResampleDesc(method = "Holdout", split = 2/3,
 stratify = TRUE)#重采样描述


1.1.2 执行 Hold-out cross-validation

holdoutCV <- resample(learner = knn, task = diabetesTask,
 resampling = holdout, measures = list(mmce, acc))#交叉验证


我们将创建的任务、 learner 和刚才定义的重采样方法提供给 resample() 函数,并要求 resample() 计算 mmceacc

运行上面代码会直接得到结果,也可以使用 holdoutCV$aggr 得到,如下所示:

holdoutCV$aggr
#mmce.test.mean  acc.test.mean 
#     0.1632653      0.8367347


Hold-out cross-validation 得到的模型的准确性低于我们在用来训练完整模型的数据上评估的准确性。这证明了之前的观点,即模型在训练它们的数据上比在未见的数据上表现得更好。


1.1.3 计算混淆矩阵

为了更好地了解哪些实例被正确分类,哪些实例被错误分类,我们可以构造一个混淆矩阵。混淆矩阵是测试集中每个实例的真实类和预测类的表格表示。

mlr 包中,使用 calculateConfusionMatrix() 函数可计算混淆矩阵。该函数的第一个参数为 holdoutCV$pred 部分,包含测试集的真实类和预测类;可选参数 relative 要求函数显示每个类在 truepredicted 类标签中的比例。

#计算混淆矩阵
calculateConfusionMatrix(holdoutCV$pred, relative = TRUE)
#Relative confusion matrix (normalized by row/column):
#          predicted
#true       Chemical  Normal    Overt     -err.-   
#  Chemical 0.83/0.62 0.08/0.04 0.08/0.12 0.17     
#  Normal   0.08/0.12 0.92/0.96 0.00/0.00 0.08     
#  Overt    0.36/0.25 0.00/0.00 0.64/0.88 0.36     
#  -err.-        0.38      0.04      0.12 0.16     
#Absolute confusion matrix:
#          predicted
#true       Chemical Normal Overt -err.-
#  Chemical       10      1     1      2
#  Normal          2     24     0      2
#  Overt           4      0     7      4
#  -err.-          6      1     1      8

绝对混淆矩阵更容易解释。行显示真正类标签,列显示预测类标签。这些数字表示真实类和预测类的每一种组合中的情况数。例如,在这个矩阵中,24 名患者被正确地归类为非糖尿病,但 2 名患者被错误地归类为化学糖尿病。在矩阵的对角线上可以找到正确分类的病人。


相对混淆矩阵中,不是真实类和预测类的组合的情况数,而是比例。/ 前面的数字是这一行在这一列的比例,/ 后面的数字是这一列在这一行的比例。例如,在这个矩阵中,92% 的非糖尿病被正确分类,而 8% 被错误分类为化学糖尿病患者。


混淆矩阵帮助我们了解我们的模型对哪些类分类得好,哪些类分类得差。例如,基于这种交叉验证,我们的模型似乎很难区分非糖尿病患者和化学糖尿病患者。


这种交叉验证方法的唯一真正的好处是它比其他形式的交叉验证计算量更小。这使得它成为计算量大的算法中唯一可行的交叉验证方法。


1.2 k-fold cross-validation

k-fold cross-validation 中,随机地将数据分成大约相等大小的块,称为 fold。然后保留其中一个 fold 作为测试集,并使用剩余的数据作为训练集。使用测试集测试模型,并记录相关的性能指标。使用不同的数据 fold 作为测试集,并执行相同的操作,直到所有的 fold 都被用作测试集。最后将得到的所有性能指标求平均值来作为模型性能的估计。该交叉验证方法过程如 Fig 2 所示:


Fig 2. k-fold cross-validation 过程


通常,实际中更倾向于使用 repeated k-fold cross-validation,而不是普通的 k-fold cross-validation 。k 值的选择取决于数据的大小,但对于许多数据集来说,10 是一个合理的值,即将数据分成 10 个大小相近的 fold ,并执行交叉验证。如果将这个过程重复 5 次,即有 10-fold 交叉验证重复 5 次 (这与 50 次交叉验证不同),模型性能的估计将是 50 个结果的平均值。


1.2.1 执行 k-fold cross-validation

kFold <- makeResampleDesc(method = "RepCV", folds = 10, reps = 50,
                          stratify = TRUE)#重采样描述
kFoldCV <- resample(learner = knn, task = diabetesTask,
                    resampling = kFold, measures = list(mmce, acc))#交叉验证

在重采样描述时,method = "RepCV" 说明使用的是 repeated k-fold cross-validationfold 个数为 10 并重复 50 次,最终会有 500 个计算结果。

提取平均性能度量:

kFoldCV$aggr
# mmce.test.mean  acc.test.mean 
#     0.1030395      0.8969605

因此,该模型平均正确分类 89.7% 的实例,低于用来训练模型的数据的结果。


1.2.2 如何选择重复次数

一种合理的方法是选择在计算上合理的多次重复,运行该过程几次,然后看看平均性能估计是否有很大差异,如果变化很大,应该增加重复的次数。一般来说,重复次数越多,这些估计就越准确和稳定。但是,在某些情况下,更多的重复并不会提高性能评估的准确性或稳定性。


1.2.3 计算混淆矩阵

Hold-out cross-validation 中计算混淆矩阵相同:

calculateConfusionMatrix(kFoldCV$pred, relative = TRUE)


1.3 leave-one-out cross-validation

leave-one-out cross-validation 可以被认为是极端的 k-fold cross-validation: 不是将数据分解成 fold,而是只保留一个观察值作为一个测试集,在剩余数据上训练模型。使用测试集测试模型,并记录相关的性能指标。使用不同的观察值作为测试集,并执行相同的操作,直到所有的观察值都被用作测试集。最后将得到的所有性能指标求平均值来作为模型性能的估计。该交叉验证方法过程如 Fig 3 所示:


Fig 3. leave-one-out cross-validation 过程

对于小数据集,若分成 k 个 fold 会留下一个非常小的训练集,在小数据集上训练的模型的方差往往更高,因为它会受到更多的抽样误差或异常情况的影响。因此,leave-one-out cross-validation 对于小数据集是有用的,它在计算上也比 repeated k-fold cross-validation 更方便。


1.3.1 执行 leave-one-out cross-validation

该交叉验证方法的重采样描述很简单,指定参数 method = "LOO" 即可。因为测试集只有一个实例,故无需设定 stratify = TRUE;因为每个实例都被用作测试集,而所有其他数据都被用作训练集,所以不需要重复这个过程。

LOO <- makeResampleDesc(method = "LOO")#重采样描述

运行交叉验证并获得平均性能度量:

LOOCV <- resample(learner = knn, task = diabetesTask, resampling = LOO,
                  measures = list(mmce, acc))#交叉验证
LOOCV$aggr
# mmce.test.mean  acc.test.mean 
#    0.08965517     0.91034483


1.3.2 计算混淆矩阵

calculateConfusionMatrix(LOOCV$pred, relative = TRUE)

现在我们已经知道如何应用三种常用的交叉验证方法。如果我们已经交叉验证了我们的模型,并且它能够在未见的数据上表现得足够好,那么就可以在所有可用的数据上训练这个模型,并使用它来做未来的预测。


2. 如何选择参数 k 来优化 KNN 模型

在 KNN 算法中, k 属于超参数,即可以控制模型预测效果的变量或选项,不能由数据进行估计得到。通常有以下三种方法来选择超参数:

  • 选择一个“合理的”或默认值,它以前处理过类似的问题。
  • 手动尝试几个不同的值,看看哪个值的性能最好。
  • 使用称为 hyperparameter tuning 的自动选择过程。

其中第三种方法是最优的,下面将着重介绍第三种方法:

  • Step 1. 定义超参数及范围(超参数空间)。
knnParamSpace <- makeParamSet(makeDiscreteParam("k", values = 1:10))

makeParamSet() 函数中指定要调优的参数 k,范围为 1-10。makeDiscreteParam() 函数用于定义离散的超参数。如果想在调优过程中调优多个超参数,只需在函数内部用逗号将它们分隔开。

  • Step 2. 搜索超参数空间。


事实上,搜索方法有很多种,下面我们将使用网格搜索 (grid search)。这可能是最简单的方法,在寻找最佳性能值时,只需尝试超参数空间中的每一个值。对于连续超参数或有多个超参数时,更倾向于使用 random search

gridSearch <- makeTuneControlGrid()
  • Step 3. 交叉验证调优过程。
cvForTuning <- makeResampleDesc("RepCV", folds = 10, reps = 20)

这里使用的交叉验证方法为 repeated k-fold cross-validation。对于 每一个 k 值,在所有这些迭代中进行平均性能度量,并与所有其他 k 值的平均性能度量比较。

  • Step 4. 调用函数 tuneParams() 调优
tunedK <- tuneParams("classif.knn", task = diabetesTask,
 resampling = cvForTuning,
 par.set = knnParamSpace, control = gridSearch)


其中,第一个参数为算法名称,第二个参数为之前定义的任务,第三个参数为交叉验证调优方法,第四个参数为定义的超参数空间,最后一个参数为搜索方法。

调用 tunedK 可得到最优的 k 值:

tunedK
#Tune result:
#Op. pars: k=7
#mmce.test.mean=0.0750476


也可以通过选择 $x 组件直接得到性能最好的 k 值:

tunedK$x
#$k
#[1] 7


另外,还可以可视化调优过程:

knnTuningData <- generateHyperParsEffectData(tunedK)
plotHyperParsEffect(knnTuningData, x = "k", y = "mmce.test.mean",
 plot.type = "line") +
 theme_bw()


Fig 4. 可视化调优过程

最终,我们可以使用调优得到的 k 值训练我们的最终模型:

tunedKnn <- setHyperPars(makeLearner("classif.knn"),
 par.vals = tunedK$x)
tunedKnnModel <- train(tunedKnn, diabetesTask)

类似于 makeLearner() 函数,在 setHyperPars() 函数中创建了一个新的 learner。再使用 train() 函数训练最终的模型。


3. 嵌套交叉验证


3.1 嵌套交叉验证

当我们对数据或模型执行某种预处理时,比如调优超参数,重要的是要将这种预处理包括到交叉验证中,这样就可以交叉验证整个模型训练过程。

这采用了嵌套交叉验证的形式,其中有一个内部循环来交叉验证超参数的不同值(就像上面做的那样),然后,最优的超参数值被传递到外部交叉验证循环。在外部交叉验证循环中,每个 fold 都使用最优超参数。


Fig 5. 嵌套交叉验证


在 Fig 5 中,外部是 3-fold cross-validation 循环,对于每个 fold,只使用外部循环的训练集来进行内部 4-fold cross-validation。对于每个内部循环,使用不同的 k 值,最优的 k 值被传递到外部循环中用来训练模型并使用测试集评估模型性能。


使用 mlr 包中的函数可以很简单地实现嵌套交叉验证过程。

  • Step 1. 定义外部和内部交叉验证。
inner <- makeResampleDesc("CV")
outer <- makeResampleDesc("RepCV", folds = 10, reps = 5)

对内部循环执行普通 k-fold cross-validation(10 是默认的折叠次数),对外部循环执行 10-fold cross-validation (重复 5 次)。

  • Step 2. 定义 wrapper

基本上是一个 learner,与一些预处理步骤联系在一起,在本文的例子中,是超参数调优,故使用函数 makeTuneWrapper():

knnWrapper <- makeTuneWrapper("classif.knn", resampling = inner,
 par.set = knnParamSpace,
 control = gridSearch)

函数 makeTuneWrapper() 中第一个参数为算法,第二个为重采样参数,为内部交叉验证过程,第三个为 par.set 参数,是超参数搜索空间,第四个 control 参数为 gridSearch 方法。

  • Step 3. 运行嵌套交叉验证过程。
cvWithTuning <- resample(knnWrapper, diabetesTask, resampling = outer)

第一个参数是我们刚才创建的 wrapper ,第二个参数是任务的名称,第三个参数设为外部交叉验证。

调用 cvWithTuning 可得结果:

cvWithTuning
#Resample Result
#Task: diabetesTib
#Learner: classif.knn.tuned
#Aggr perf: mmce.test.mean=0.0857143
#Runtime: 57.1177

对于未见的数据,该模型估计能正确分类91.4%的病例。


3.2 利用模型进行预测

假设有一些新的病人来到诊所:

newDiabetesPatients <- tibble(glucose = c(82, 108, 300),
 insulin = c(361, 288, 1052),
 sspg = c(200, 186, 135))

将这些患者输入到模型中,得到他们的预测糖尿病状态:

newPatientsPred <- predict(tunedKnnModel, newdata = newDiabetesPatients)
getPredictionResponse(newPatientsPred)
#[1] Normal Normal Overt 
#Levels: Chemical Normal Overt
目录
相关文章
|
20天前
|
机器学习/深度学习 人工智能 算法
深入解析图神经网络:Graph Transformer的算法基础与工程实践
Graph Transformer是一种结合了Transformer自注意力机制与图神经网络(GNNs)特点的神经网络模型,专为处理图结构数据而设计。它通过改进的数据表示方法、自注意力机制、拉普拉斯位置编码、消息传递与聚合机制等核心技术,实现了对图中节点间关系信息的高效处理及长程依赖关系的捕捉,显著提升了图相关任务的性能。本文详细解析了Graph Transformer的技术原理、实现细节及应用场景,并通过图书推荐系统的实例,展示了其在实际问题解决中的强大能力。
115 30
|
24天前
|
存储 算法
深入解析PID控制算法:从理论到实践的完整指南
前言 大家好,今天我们介绍一下经典控制理论中的PID控制算法,并着重讲解该算法的编码实现,为实现后续的倒立摆样例内容做准备。 众所周知,掌握了 PID ,就相当于进入了控制工程的大门,也能为更高阶的控制理论学习打下基础。 在很多的自动化控制领域。都会遇到PID控制算法,这种算法具有很好的控制模式,可以让系统具有很好的鲁棒性。 基本介绍 PID 深入理解 (1)闭环控制系统:讲解 PID 之前,我们先解释什么是闭环控制系统。简单说就是一个有输入有输出的系统,输入能影响输出。一般情况下,人们也称输出为反馈,因此也叫闭环反馈控制系统。比如恒温水池,输入就是加热功率,输出就是水温度;比如冷库,
192 15
|
1月前
|
存储 算法 安全
2024重生之回溯数据结构与算法系列学习之串(12)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丟脸好嘛?】
数据结构与算法系列学习之串的定义和基本操作、串的储存结构、基本操作的实现、朴素模式匹配算法、KMP算法等代码举例及图解说明;【含常见的报错问题及其对应的解决方法】你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
2024重生之回溯数据结构与算法系列学习之串(12)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丟脸好嘛?】
|
1月前
|
算法 安全 搜索推荐
2024重生之回溯数据结构与算法系列学习(8)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
数据结构王道第2.3章之IKUN和I原达人之数据结构与算法系列学习x单双链表精题详解、数据结构、C++、排序算法、java、动态规划你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
|
1月前
|
算法 安全 搜索推荐
2024重生之回溯数据结构与算法系列学习之单双链表精题详解(9)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
数据结构王道第2.3章之IKUN和I原达人之数据结构与算法系列学习x单双链表精题详解、数据结构、C++、排序算法、java、动态规划你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
|
1月前
|
存储 Web App开发 算法
2024重生之回溯数据结构与算法系列学习之单双链表【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
数据结构之单双链表按位、值查找;[前后]插入;删除指定节点;求表长、静态链表等代码及具体思路详解步骤;举例说明、注意点及常见报错问题所对应的解决方法
|
1月前
|
算法 安全 NoSQL
2024重生之回溯数据结构与算法系列学习之栈和队列精题汇总(10)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
数据结构王道第3章之IKUN和I原达人之数据结构与算法系列学习栈与队列精题详解、数据结构、C++、排序算法、java、动态规划你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
|
1月前
|
算法 安全 搜索推荐
2024重生之回溯数据结构与算法系列学习之王道第2.3章节之线性表精题汇总二(5)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
IKU达人之数据结构与算法系列学习×单双链表精题详解、数据结构、C++、排序算法、java 、动态规划 你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
|
2天前
|
机器学习/深度学习 算法
基于改进遗传优化的BP神经网络金融序列预测算法matlab仿真
本项目基于改进遗传优化的BP神经网络进行金融序列预测,使用MATLAB2022A实现。通过对比BP神经网络、遗传优化BP神经网络及改进遗传优化BP神经网络,展示了三者的误差和预测曲线差异。核心程序结合遗传算法(GA)与BP神经网络,利用GA优化BP网络的初始权重和阈值,提高预测精度。GA通过选择、交叉、变异操作迭代优化,防止局部收敛,增强模型对金融市场复杂性和不确定性的适应能力。
104 80
|
21天前
|
算法
基于WOA算法的SVDD参数寻优matlab仿真
该程序利用鲸鱼优化算法(WOA)对支持向量数据描述(SVDD)模型的参数进行优化,以提高数据分类的准确性。通过MATLAB2022A实现,展示了不同信噪比(SNR)下模型的分类误差。WOA通过模拟鲸鱼捕食行为,动态调整SVDD参数,如惩罚因子C和核函数参数γ,以寻找最优参数组合,增强模型的鲁棒性和泛化能力。