基于 mlr 包的 K 最近邻算法介绍与实践(下)

简介: 在上期 KNN 算法介绍 的最后,我们指出:使用最初用来训练模型的数据进行预测的方式来评估模型性能是不合理的。本期将以上期的内容和数据为基础,介绍交叉验证的方法来评估模型性能、如何选择参数 k 来优化模型等内容。

前言


在上期 KNN 算法介绍 的最后,我们指出:使用最初用来训练模型的数据进行预测的方式来评估模型性能是不合理的。本期将以上期的内容和数据为基础,介绍交叉验证的方法来评估模型性能、如何选择参数 k 来优化模型等内容。


1. 交叉验证


通常情况下,我们会将已有的数据分为两部分:训练集 (training set) 和测试集 (test set)。使用训练集来训练模型,并用测试集的数据来评估模型性能。这个过程叫做交叉验证 (cross-validation)。常见的交叉验证方法有以下三种:

  • Hold-out cross-validation.
  • k-fold cross-validation.
  • leave-one-out cross-validation.

接下来,本文将从上期创建的任务和 learner 出发来分别介绍以上三种交叉验证方法。

diabetesTask <- makeClassifTask(data = diabetesTib, target = "class")
knn <- makeLearner("classif.knn", par.vals = list("k" = 2))


1.1 Hold-out cross-validation

Hold-out cross-validation 是最容易理解的方法:我们只需随机 "保留" 一个一定比例的数据作为测试集,并在剩余数据上训练模型,然后使用测试集来评估模型性能。

image.gifFig 1. Hold-out cross-validation 过程

在采用这种方法时,需要决定将多大比例的数据用作测试集。如果测试集太小,那么对性能的估计就会有很大的方差,但若训练集太小,那么对性能的估计就会有很大的偏差。通常,2/3的数据用于训练集,1/3用作测试集,但这也取决于数据中实例的数量。


1.1.1 Holdout 重采样描述

mlr 包中使用交叉验证,第一步是进行重采样描述,这是一组简单的指令,用于将数据分割成测试集和训练集。

makeResampleDesc() 函数的第一个参数是要使用的交叉验证方法,在本例中是 Holdout;第二个参数 split,用来设定多大比例的数据将被用作训练集;stratify = TRUE 确保在将数据拆分为训练集和测试集时,尽量保持糖尿病患者在每一类中的比例。

holdout <- makeResampleDesc(method = "Holdout", split = 2/3,
 stratify = TRUE)#重采样描述


1.1.2 执行 Hold-out cross-validation

holdoutCV <- resample(learner = knn, task = diabetesTask,
 resampling = holdout, measures = list(mmce, acc))#交叉验证


我们将创建的任务、 learner 和刚才定义的重采样方法提供给 resample() 函数,并要求 resample() 计算 mmceacc

运行上面代码会直接得到结果,也可以使用 holdoutCV$aggr 得到,如下所示:

holdoutCV$aggr
#mmce.test.mean  acc.test.mean 
#     0.1632653      0.8367347


Hold-out cross-validation 得到的模型的准确性低于我们在用来训练完整模型的数据上评估的准确性。这证明了之前的观点,即模型在训练它们的数据上比在未见的数据上表现得更好。


1.1.3 计算混淆矩阵

为了更好地了解哪些实例被正确分类,哪些实例被错误分类,我们可以构造一个混淆矩阵。混淆矩阵是测试集中每个实例的真实类和预测类的表格表示。

mlr 包中,使用 calculateConfusionMatrix() 函数可计算混淆矩阵。该函数的第一个参数为 holdoutCV$pred 部分,包含测试集的真实类和预测类;可选参数 relative 要求函数显示每个类在 truepredicted 类标签中的比例。

#计算混淆矩阵
calculateConfusionMatrix(holdoutCV$pred, relative = TRUE)
#Relative confusion matrix (normalized by row/column):
#          predicted
#true       Chemical  Normal    Overt     -err.-   
#  Chemical 0.83/0.62 0.08/0.04 0.08/0.12 0.17     
#  Normal   0.08/0.12 0.92/0.96 0.00/0.00 0.08     
#  Overt    0.36/0.25 0.00/0.00 0.64/0.88 0.36     
#  -err.-        0.38      0.04      0.12 0.16     
#Absolute confusion matrix:
#          predicted
#true       Chemical Normal Overt -err.-
#  Chemical       10      1     1      2
#  Normal          2     24     0      2
#  Overt           4      0     7      4
#  -err.-          6      1     1      8

绝对混淆矩阵更容易解释。行显示真正类标签,列显示预测类标签。这些数字表示真实类和预测类的每一种组合中的情况数。例如,在这个矩阵中,24 名患者被正确地归类为非糖尿病,但 2 名患者被错误地归类为化学糖尿病。在矩阵的对角线上可以找到正确分类的病人。


相对混淆矩阵中,不是真实类和预测类的组合的情况数,而是比例。/ 前面的数字是这一行在这一列的比例,/ 后面的数字是这一列在这一行的比例。例如,在这个矩阵中,92% 的非糖尿病被正确分类,而 8% 被错误分类为化学糖尿病患者。


混淆矩阵帮助我们了解我们的模型对哪些类分类得好,哪些类分类得差。例如,基于这种交叉验证,我们的模型似乎很难区分非糖尿病患者和化学糖尿病患者。


这种交叉验证方法的唯一真正的好处是它比其他形式的交叉验证计算量更小。这使得它成为计算量大的算法中唯一可行的交叉验证方法。


1.2 k-fold cross-validation

k-fold cross-validation 中,随机地将数据分成大约相等大小的块,称为 fold。然后保留其中一个 fold 作为测试集,并使用剩余的数据作为训练集。使用测试集测试模型,并记录相关的性能指标。使用不同的数据 fold 作为测试集,并执行相同的操作,直到所有的 fold 都被用作测试集。最后将得到的所有性能指标求平均值来作为模型性能的估计。该交叉验证方法过程如 Fig 2 所示:


Fig 2. k-fold cross-validation 过程


通常,实际中更倾向于使用 repeated k-fold cross-validation,而不是普通的 k-fold cross-validation 。k 值的选择取决于数据的大小,但对于许多数据集来说,10 是一个合理的值,即将数据分成 10 个大小相近的 fold ,并执行交叉验证。如果将这个过程重复 5 次,即有 10-fold 交叉验证重复 5 次 (这与 50 次交叉验证不同),模型性能的估计将是 50 个结果的平均值。


1.2.1 执行 k-fold cross-validation

kFold <- makeResampleDesc(method = "RepCV", folds = 10, reps = 50,
                          stratify = TRUE)#重采样描述
kFoldCV <- resample(learner = knn, task = diabetesTask,
                    resampling = kFold, measures = list(mmce, acc))#交叉验证

在重采样描述时,method = "RepCV" 说明使用的是 repeated k-fold cross-validationfold 个数为 10 并重复 50 次,最终会有 500 个计算结果。

提取平均性能度量:

kFoldCV$aggr
# mmce.test.mean  acc.test.mean 
#     0.1030395      0.8969605

因此,该模型平均正确分类 89.7% 的实例,低于用来训练模型的数据的结果。


1.2.2 如何选择重复次数

一种合理的方法是选择在计算上合理的多次重复,运行该过程几次,然后看看平均性能估计是否有很大差异,如果变化很大,应该增加重复的次数。一般来说,重复次数越多,这些估计就越准确和稳定。但是,在某些情况下,更多的重复并不会提高性能评估的准确性或稳定性。


1.2.3 计算混淆矩阵

Hold-out cross-validation 中计算混淆矩阵相同:

calculateConfusionMatrix(kFoldCV$pred, relative = TRUE)


1.3 leave-one-out cross-validation

leave-one-out cross-validation 可以被认为是极端的 k-fold cross-validation: 不是将数据分解成 fold,而是只保留一个观察值作为一个测试集,在剩余数据上训练模型。使用测试集测试模型,并记录相关的性能指标。使用不同的观察值作为测试集,并执行相同的操作,直到所有的观察值都被用作测试集。最后将得到的所有性能指标求平均值来作为模型性能的估计。该交叉验证方法过程如 Fig 3 所示:


Fig 3. leave-one-out cross-validation 过程

对于小数据集,若分成 k 个 fold 会留下一个非常小的训练集,在小数据集上训练的模型的方差往往更高,因为它会受到更多的抽样误差或异常情况的影响。因此,leave-one-out cross-validation 对于小数据集是有用的,它在计算上也比 repeated k-fold cross-validation 更方便。


1.3.1 执行 leave-one-out cross-validation

该交叉验证方法的重采样描述很简单,指定参数 method = "LOO" 即可。因为测试集只有一个实例,故无需设定 stratify = TRUE;因为每个实例都被用作测试集,而所有其他数据都被用作训练集,所以不需要重复这个过程。

LOO <- makeResampleDesc(method = "LOO")#重采样描述

运行交叉验证并获得平均性能度量:

LOOCV <- resample(learner = knn, task = diabetesTask, resampling = LOO,
                  measures = list(mmce, acc))#交叉验证
LOOCV$aggr
# mmce.test.mean  acc.test.mean 
#    0.08965517     0.91034483


1.3.2 计算混淆矩阵

calculateConfusionMatrix(LOOCV$pred, relative = TRUE)

现在我们已经知道如何应用三种常用的交叉验证方法。如果我们已经交叉验证了我们的模型,并且它能够在未见的数据上表现得足够好,那么就可以在所有可用的数据上训练这个模型,并使用它来做未来的预测。


2. 如何选择参数 k 来优化 KNN 模型

在 KNN 算法中, k 属于超参数,即可以控制模型预测效果的变量或选项,不能由数据进行估计得到。通常有以下三种方法来选择超参数:

  • 选择一个“合理的”或默认值,它以前处理过类似的问题。
  • 手动尝试几个不同的值,看看哪个值的性能最好。
  • 使用称为 hyperparameter tuning 的自动选择过程。

其中第三种方法是最优的,下面将着重介绍第三种方法:

  • Step 1. 定义超参数及范围(超参数空间)。
knnParamSpace <- makeParamSet(makeDiscreteParam("k", values = 1:10))

makeParamSet() 函数中指定要调优的参数 k,范围为 1-10。makeDiscreteParam() 函数用于定义离散的超参数。如果想在调优过程中调优多个超参数,只需在函数内部用逗号将它们分隔开。

  • Step 2. 搜索超参数空间。


事实上,搜索方法有很多种,下面我们将使用网格搜索 (grid search)。这可能是最简单的方法,在寻找最佳性能值时,只需尝试超参数空间中的每一个值。对于连续超参数或有多个超参数时,更倾向于使用 random search

gridSearch <- makeTuneControlGrid()
  • Step 3. 交叉验证调优过程。
cvForTuning <- makeResampleDesc("RepCV", folds = 10, reps = 20)

这里使用的交叉验证方法为 repeated k-fold cross-validation。对于 每一个 k 值,在所有这些迭代中进行平均性能度量,并与所有其他 k 值的平均性能度量比较。

  • Step 4. 调用函数 tuneParams() 调优
tunedK <- tuneParams("classif.knn", task = diabetesTask,
 resampling = cvForTuning,
 par.set = knnParamSpace, control = gridSearch)


其中,第一个参数为算法名称,第二个参数为之前定义的任务,第三个参数为交叉验证调优方法,第四个参数为定义的超参数空间,最后一个参数为搜索方法。

调用 tunedK 可得到最优的 k 值:

tunedK
#Tune result:
#Op. pars: k=7
#mmce.test.mean=0.0750476


也可以通过选择 $x 组件直接得到性能最好的 k 值:

tunedK$x
#$k
#[1] 7


另外,还可以可视化调优过程:

knnTuningData <- generateHyperParsEffectData(tunedK)
plotHyperParsEffect(knnTuningData, x = "k", y = "mmce.test.mean",
 plot.type = "line") +
 theme_bw()


Fig 4. 可视化调优过程

最终,我们可以使用调优得到的 k 值训练我们的最终模型:

tunedKnn <- setHyperPars(makeLearner("classif.knn"),
 par.vals = tunedK$x)
tunedKnnModel <- train(tunedKnn, diabetesTask)

类似于 makeLearner() 函数,在 setHyperPars() 函数中创建了一个新的 learner。再使用 train() 函数训练最终的模型。


3. 嵌套交叉验证


3.1 嵌套交叉验证

当我们对数据或模型执行某种预处理时,比如调优超参数,重要的是要将这种预处理包括到交叉验证中,这样就可以交叉验证整个模型训练过程。

这采用了嵌套交叉验证的形式,其中有一个内部循环来交叉验证超参数的不同值(就像上面做的那样),然后,最优的超参数值被传递到外部交叉验证循环。在外部交叉验证循环中,每个 fold 都使用最优超参数。


Fig 5. 嵌套交叉验证


在 Fig 5 中,外部是 3-fold cross-validation 循环,对于每个 fold,只使用外部循环的训练集来进行内部 4-fold cross-validation。对于每个内部循环,使用不同的 k 值,最优的 k 值被传递到外部循环中用来训练模型并使用测试集评估模型性能。


使用 mlr 包中的函数可以很简单地实现嵌套交叉验证过程。

  • Step 1. 定义外部和内部交叉验证。
inner <- makeResampleDesc("CV")
outer <- makeResampleDesc("RepCV", folds = 10, reps = 5)

对内部循环执行普通 k-fold cross-validation(10 是默认的折叠次数),对外部循环执行 10-fold cross-validation (重复 5 次)。

  • Step 2. 定义 wrapper

基本上是一个 learner,与一些预处理步骤联系在一起,在本文的例子中,是超参数调优,故使用函数 makeTuneWrapper():

knnWrapper <- makeTuneWrapper("classif.knn", resampling = inner,
 par.set = knnParamSpace,
 control = gridSearch)

函数 makeTuneWrapper() 中第一个参数为算法,第二个为重采样参数,为内部交叉验证过程,第三个为 par.set 参数,是超参数搜索空间,第四个 control 参数为 gridSearch 方法。

  • Step 3. 运行嵌套交叉验证过程。
cvWithTuning <- resample(knnWrapper, diabetesTask, resampling = outer)

第一个参数是我们刚才创建的 wrapper ,第二个参数是任务的名称,第三个参数设为外部交叉验证。

调用 cvWithTuning 可得结果:

cvWithTuning
#Resample Result
#Task: diabetesTib
#Learner: classif.knn.tuned
#Aggr perf: mmce.test.mean=0.0857143
#Runtime: 57.1177

对于未见的数据,该模型估计能正确分类91.4%的病例。


3.2 利用模型进行预测

假设有一些新的病人来到诊所:

newDiabetesPatients <- tibble(glucose = c(82, 108, 300),
 insulin = c(361, 288, 1052),
 sspg = c(200, 186, 135))

将这些患者输入到模型中,得到他们的预测糖尿病状态:

newPatientsPred <- predict(tunedKnnModel, newdata = newDiabetesPatients)
getPredictionResponse(newPatientsPred)
#[1] Normal Normal Overt 
#Levels: Chemical Normal Overt
目录
相关文章
机器学习/深度学习 算法 自动驾驶
301 0
|
2月前
|
算法 API 数据安全/隐私保护
深度解析京东图片搜索API:从图像识别到商品匹配的算法实践
京东图片搜索API基于图像识别技术,支持通过上传图片或图片URL搜索相似商品,提供智能匹配、结果筛选、分页查询等功能。适用于比价、竞品分析、推荐系统等场景。支持Python等开发语言,提供详细请求示例与文档。
|
5月前
|
监控 算法 安全
公司电脑监控软件关键技术探析:C# 环形缓冲区算法的理论与实践
环形缓冲区(Ring Buffer)是企业信息安全管理中电脑监控系统设计的核心数据结构,适用于高并发、高速率与短时有效的多源异构数据处理场景。其通过固定大小的连续内存空间实现闭环存储,具备内存优化、操作高效、数据时效管理和并发支持等优势。文章以C#语言为例,展示了线程安全的环形缓冲区实现,并结合URL访问记录监控应用场景,分析了其在流量削峰、关键数据保护和高性能处理中的适配性。该结构在日志捕获和事件缓冲中表现出色,对提升监控系统效能具有重要价值。
124 1
|
6月前
|
监控 算法 数据处理
基于 C++ 的 KD 树算法在监控局域网屏幕中的理论剖析与工程实践研究
本文探讨了KD树在局域网屏幕监控中的应用,通过C++实现其构建与查询功能,显著提升多维数据处理效率。KD树作为一种二叉空间划分结构,适用于屏幕图像特征匹配、异常画面检测及数据压缩传输优化等场景。相比传统方法,基于KD树的方案检索效率提升2-3个数量级,但高维数据退化和动态更新等问题仍需进一步研究。未来可通过融合其他数据结构、引入深度学习及开发增量式更新算法等方式优化性能。
169 17
|
6月前
|
存储 算法 安全
如何控制上网行为——基于 C# 实现布隆过滤器算法的上网行为管控策略研究与实践解析
在数字化办公生态系统中,企业对员工网络行为的精细化管理已成为保障网络安全、提升组织效能的核心命题。如何在有效防范恶意网站访问、数据泄露风险的同时,避免过度管控对正常业务运作的负面影响,构成了企业网络安全领域的重要研究方向。在此背景下,数据结构与算法作为底层技术支撑,其重要性愈发凸显。本文将以布隆过滤器算法为研究对象,基于 C# 编程语言开展理论分析与工程实践,系统探讨该算法在企业上网行为管理中的应用范式。
162 8
|
6月前
|
存储 监控 算法
基于 C# 时间轮算法的控制局域网上网时间与实践应用
在数字化办公与教育环境中,局域网作为内部网络通信的核心基础设施,其精细化管理水平直接影响网络资源的合理配置与使用效能。对局域网用户上网时间的有效管控,已成为企业、教育机构等组织的重要管理需求。这一需求不仅旨在提升员工工作效率、规范学生网络使用行为,更是优化网络带宽资源分配的关键举措。时间轮算法作为一种经典的定时任务管理机制,在局域网用户上网时间管控场景中展现出显著的技术优势。本文将系统阐述时间轮算法的核心原理,并基于 C# 编程语言提供具体实现方案,以期深入剖析该算法在局域网管理中的应用逻辑与实践价值。
137 5
|
11月前
|
机器学习/深度学习 人工智能 算法
深入解析图神经网络:Graph Transformer的算法基础与工程实践
Graph Transformer是一种结合了Transformer自注意力机制与图神经网络(GNNs)特点的神经网络模型,专为处理图结构数据而设计。它通过改进的数据表示方法、自注意力机制、拉普拉斯位置编码、消息传递与聚合机制等核心技术,实现了对图中节点间关系信息的高效处理及长程依赖关系的捕捉,显著提升了图相关任务的性能。本文详细解析了Graph Transformer的技术原理、实现细节及应用场景,并通过图书推荐系统的实例,展示了其在实际问题解决中的强大能力。
1209 30
|
11月前
|
存储 算法
深入解析PID控制算法:从理论到实践的完整指南
前言 大家好,今天我们介绍一下经典控制理论中的PID控制算法,并着重讲解该算法的编码实现,为实现后续的倒立摆样例内容做准备。 众所周知,掌握了 PID ,就相当于进入了控制工程的大门,也能为更高阶的控制理论学习打下基础。 在很多的自动化控制领域。都会遇到PID控制算法,这种算法具有很好的控制模式,可以让系统具有很好的鲁棒性。 基本介绍 PID 深入理解 (1)闭环控制系统:讲解 PID 之前,我们先解释什么是闭环控制系统。简单说就是一个有输入有输出的系统,输入能影响输出。一般情况下,人们也称输出为反馈,因此也叫闭环反馈控制系统。比如恒温水池,输入就是加热功率,输出就是水温度;比如冷库,
1463 15
|
12月前
|
存储 算法 安全
2024重生之回溯数据结构与算法系列学习之串(12)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丟脸好嘛?】
数据结构与算法系列学习之串的定义和基本操作、串的储存结构、基本操作的实现、朴素模式匹配算法、KMP算法等代码举例及图解说明;【含常见的报错问题及其对应的解决方法】你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!
2024重生之回溯数据结构与算法系列学习之串(12)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丟脸好嘛?】
|
12月前
|
算法 安全 NoSQL
2024重生之回溯数据结构与算法系列学习之栈和队列精题汇总(10)【无论是王道考研人还是IKUN都能包会的;不然别给我家鸽鸽丢脸好嘛?】
数据结构王道第3章之IKUN和I原达人之数据结构与算法系列学习栈与队列精题详解、数据结构、C++、排序算法、java、动态规划你个小黑子;这都学不会;能不能不要给我家鸽鸽丢脸啊~除了会黑我家鸽鸽还会干嘛?!!!

热门文章

最新文章