轻松玩转 Scikit-Learn 系列 —— KNN 算法

简介: scikit-learn 是最受欢迎的机器学习库之一,它提供了各种主流的机器学习算法的API接口供使用者调用,让使用者可以方便快捷的搭建一些机器学习模型,并且通过调参可以达到很高的准确率。

scikit-learn 是最受欢迎的机器学习库之一,它提供了各种主流的机器学习算法的API接口供使用者调用,让使用者可以方便快捷的搭建一些机器学习模型,并且通过调参可以达到很高的准确率。

这次我们主要介绍scikit-learn中k近邻算法(以下简称为KNN)的使用。

KNN是一种非参数机器学习算法(机器学习中通过模型训练而学到的是模型参数,而要人工调整的是超参数,请注意避免混淆)。使用KNN首先要有一个已知的数据集D,数据集内对于任意一个未知标签的样本数据x,可以通过计算x与D中所有样本点的距离,取出与x距离最近的前k个已知数据,用该k个已知数据的标签对x进行投票,哪一类票数最多,x就是哪一类,这是kNN的大概思想,以下举个例子方便理解。

42.jpg


正方形该分到哪个类?

在上图中有2个已知类别——红色五角星和蓝色三角形和一个未知样本——绿色方格。现在我们要用KNN算法对绿色方格进行分类,以判定其属于这两类中的哪一类,首先令k=5,通过计算距离我们可以知道距离绿色方格最近的5个样本中(假设绿色方格位于圆心),有2个红色五角星,3个蓝色三角形。通过投票可知:蓝色三角形得3票,红色五角星得2票,因此绿色方格应该属于蓝色三角形。kNN就是这样工作的。

上图同时也引申出KNN算法的一个重要的超参数——k。举例来说,如果当k=10时,由图可以看出:红色五角星投了6票,蓝色三角形投了4票,因此未知的样本应该属于红色五角星一类。因此,我们可以看出超参数的选择会影响最终kNN模型的预测结果。下面用代码具体展示如何调用scikit-learn使用kNN,并调整超参数。

43.jpg

44.jpg

👆 取鸢尾花数据集两个特征可视化

45.jpg


以上是利用scikit-learn中默认的k近邻模型来预测未知鸢尾花样本的种类(假装未知),我们在实例化模型的过程中并未传入任何的超参数,则kNN模型会使用模型默认的超参数。

例如:

  • metric='minkowski' —— 计算样本点之间距离的时候会采用明可夫斯基距离,与p=2等价
  • n_jobs=1 —— kNN算法支持cpu多核并行运算;n_jobs=1,默认使用一个核,当n_jobs=-1时,使用所有的核
  • n_neighbors=5 —— 表示k=5,即抽取未知样本附近最近的5个点进行投票
  • weights='uniform' —— 表示再利用最近的k个点投票时,他们的权重是等价的,当weights='distance'时,表示一个已知样本点距离未知点的距离越小,其投票时所占权重越大

还有一些其他的很重要的超参数,在这里先暂不说明,以下用代码具体展示。

46.jpg


以下用循环来搜索下关于n_neighbors、和p这两个超参数的最优值。

47.jpg



因为我们为了便于可视化,仅使用了鸢尾花数据集中的2个特征,所以导致最终预测的准确率不太高,如果使用该数据集的全部特征来训练模型并预测未知样本,传入最佳超参数的kNN模型,亲测准确度可达100%,当然这与鸢尾花数据集的高质量也有关系。运行以上代码并打印结果可得如上所示。

今天的分享就到这里了,关于kNN还有很多更复杂的超参数的调整,就不一一展示了,请小伙伴们自己在下面亲手操作下,会收获更多哦。kNN思想和实现简单,目前还在机器学习算法的领域持续的发光发热,如果你们中有大神路过,还请高抬贵脚,勿踩勿喷!

相关文章
|
1月前
|
机器学习/深度学习 算法
机器学习第14天:KNN近邻算法
机器学习第14天:KNN近邻算法
23 0
|
1月前
|
机器学习/深度学习 数据采集 算法
Machine Learning机器学习之K近邻算法(K-Nearest Neighbors,KNN)
Machine Learning机器学习之K近邻算法(K-Nearest Neighbors,KNN)
|
2天前
|
机器学习/深度学习 存储 算法
用kNN算法诊断乳腺癌--基于R语言
用kNN算法诊断乳腺癌--基于R语言
|
2天前
|
机器学习/深度学习 人工智能 算法
【机器学习】K-means和KNN算法有什么区别?
【5月更文挑战第11天】【机器学习】K-means和KNN算法有什么区别?
|
18天前
|
机器学习/深度学习 自然语言处理 算法
【视频】K近邻KNN算法原理与R语言结合新冠疫情对股票价格预测|数据分享(下)
【视频】K近邻KNN算法原理与R语言结合新冠疫情对股票价格预测|数据分享
|
18天前
|
机器学习/深度学习 算法 大数据
【视频】K近邻KNN算法原理与R语言结合新冠疫情对股票价格预测|数据分享(上)
【视频】K近邻KNN算法原理与R语言结合新冠疫情对股票价格预测|数据分享
|
26天前
|
机器学习/深度学习 算法 前端开发
Scikit-learn进阶:探索集成学习算法
【4月更文挑战第17天】本文介绍了Scikit-learn中的集成学习算法,包括Bagging(如RandomForest)、Boosting(AdaBoost、GradientBoosting)和Stacking。通过结合多个学习器,集成学习能提高模型性能,减少偏差和方差。文中展示了如何使用Scikit-learn实现这些算法,并提供示例代码,帮助读者理解和应用集成学习提升模型预测准确性。
|
27天前
电信公司churn数据客户流失k近邻(knn)模型预测分析
电信公司churn数据客户流失k近邻(knn)模型预测分析
|
3天前
|
算法 数据安全/隐私保护 计算机视觉
基于二维CS-SCHT变换和LABS方法的水印嵌入和提取算法matlab仿真
该内容包括一个算法的运行展示和详细步骤,使用了MATLAB2022a。算法涉及水印嵌入和提取,利用LAB色彩空间可能用于隐藏水印。水印通过二维CS-SCHT变换、低频系数处理和特定解码策略来提取。代码段展示了水印置乱、图像处理(如噪声、旋转、剪切等攻击)以及水印的逆置乱和提取过程。最后,计算并保存了比特率,用于评估水印的稳健性。
|
4天前
|
存储 算法 数据可视化
基于harris角点和RANSAC算法的图像拼接matlab仿真
本文介绍了使用MATLAB2022a进行图像拼接的流程,涉及Harris角点检测和RANSAC算法。Harris角点检测寻找图像中局部曲率变化显著的点,RANSAC则用于排除噪声和异常点,找到最佳匹配。核心程序包括自定义的Harris角点计算函数,RANSAC参数设置,以及匹配点的可视化和仿射变换矩阵计算,最终生成全景图像。