机器学习入门(五):KNN概述 | K 近邻算法 API,K值选择问题

简介: 机器学习入门(五):KNN概述 | K 近邻算法 API,K值选择问题

1. K 近邻算法 API

K近邻(K-Nearest Neighbors, KNN)算法作为一种基础且广泛应用的机器学习技术,其API的重要性不言而喻。它提供了快速、直接的方式来执行基于实例的学习,通过查找与待分类样本最邻近的K个样本,并基于这些邻近样本的类别来预测新样本的类别。KNN API的标准化和易用性,使得数据分析师和开发者能够轻松集成该算法到他们的项目中,无需深入算法细节,即可享受其强大的分类与回归能力。此外,KNN API通常还包含参数调整功能,如K值选择、距离度量方法等,使得用户可以根据具体需求优化算法性能,进一步凸显了其在机器学习实践中的不可或缺性。

学习目标

  1. 掌握sklearn中K近邻算法API的使用方法

1.1 Sklearn API介绍

本小节使用 scikit-learn 的 KNN API 来完成对鸢尾花数据集的预测.

  • API介绍

1.2 鸢尾花分类示例代码

鸢尾花数据集

鸢尾花Iris Dataset数据集是机器学习领域经典数据集,鸢尾花数据集包含了150条鸢尾花信息,每50条取自三个鸢尾花中之一:Versicolour、Setosa和Virginica

每个花的特征用如下属性描述:

示例代码:

from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
if __name__ == '__main__':
    # 1. 加载数据集  
    iris = load_iris() #通过iris.data 获取数据集中的特征值  iris.target获取目标值
    # 2. 数据标准化
    transformer = StandardScaler()
    x_ = transformer.fit_transform(iris.data) # iris.data 数据的特征值
    # 3. 模型训练
    estimator = KNeighborsClassifier(n_neighbors=3) # n_neighbors 邻居的数量,也就是Knn中的K值
    estimator.fit(x_, iris.target) # 调用fit方法 传入特征和目标进行模型训练
    # 4. 利用模型预测
    result = estimator.predict(x_) 
    print(result)

1.3 小结

1、sklearn中K近邻算法的对象:

from sklearn.neighbors import KNeighborsClassifier
estimator = KNeighborsClassifier(n_neighbors=3)  # K的取值通过n_neighbors传递

2、sklearn中大多数算法模型训练的API都是同一个套路

estimator = KNeighborsClassifier(n_neighbors=3) # 创建算法模型对象
estimator.fit(x_, iris.target)  # 调用fit方法训练模型
estimator.predict(x_)           # 用训练好的模型进行预测

3、sklearn中自带了几个学习数据集

  • 都封装在sklearn.datasets 这个包中
  • 加载数据后,通过data属性可以获取特征值,通过target属性可以获取目标值, 通过DESCR属性可以获取数据集的描述信息

2. K 值选择问题

K值选择问题是K近邻算法中的关键,它直接影响到算法的准确性与效率。在平衡“过拟合”与“欠拟合”需要注意:K值过小可能导致模型复杂,对新样本敏感,易于过拟合;K值过大则可能平滑类边界,忽视邻近样本的细节,造成欠拟合。因此,合理选取K值是确保K近邻算法性能的重要步骤。

学习目标

  1. 了解 K 值大小的影响
  2. 掌握 GridSearchCV 的使用

2.1 K取不同值时带来的影响

举例:

  • 有两类不同的样本数据,分别用蓝颜色的小正方形和红色的小三角形表示,而图正中间有一个绿色的待判样本。
  • 问题:如何给这个绿色的圆分类?是判断为蓝色的小正方形还是红色的小三角形?
  • 方法:应用KNN找绿色的邻居,但一次性看多少个邻居呢(K取几合适)?

解决方案:

  • K=4,绿色圆圈最近的4个邻居,3红色和1个蓝,按少数服从多数,判定绿色样本与红色三角形属于同一类别
  • K=9,绿色圆圈最近的9个邻居,6红和3个蓝,判定绿色属于红色的三角形一类。

有时候出现K值选择困难的问题

KNN算法的关键是什么?

答案一定是K值的选择,下图中K=3,属于红色三角形,K=5属于蓝色的正方形。这个时候就是K选择困难的时候。

2.2 如何确定合适的K值

K值过小:容易受到异常点的影响

k值过大:受到样本均衡的问题

K=N(N为训练样本个数):结果只取决于数据集中不同类别数量占比,得到的结果一定是占比高的类别,此时模型过于简单,忽略了训练实例中大量有用信息。

在实际应用中,K一般取一个较小的数值

我们可以采用交叉验证法(把训练数据再分成:训练集和验证集)来选择最优的K值。

2.3 GridSearchCV 的用法

使用 scikit-learn 提供的 GridSearchCV 工具, 配合交叉验证法可以搜索参数组合.

# 1. 加载数据集
x, y = load_iris(return_X_y=True)
# 2. 分割数据集
x_train, x_test, y_train, y_test = \
    train_test_split(x, y, test_size=0.2, stratify=y, random_state=0)
# 3. 创建网格搜索对象
estimator = KNeighborsClassifier()
param_grid = {'n_neighbors': [1, 3, 5, 7]}
estimator = GridSearchCV(estimator, param_grid=param_grid, cv=5, verbose=0)
estimator.fit(x_train, y_train)
# 4. 打印最优参数
print('最优参数组合:', estimator.best_params_, '最好得分:', estimator.best_score_)
# 4. 测试集评估模型
print('测试集准确率:', estimator.score(x_test, y_test))

2.4 小结

KNN 算法中K值过大、过小都不好, 一般会取一个较小的值

GridSearchCV 工具可以用来寻找最优的模型超参数,可以用来做KNN中K值的选择

K近邻算法的优缺点:

  • 优点:简单,易于理解,容易实现
  • 缺点:算法复杂度高,结果对K取值敏感,容易受数据分布影响+
相关文章
|
2月前
|
机器学习/深度学习 算法 数据挖掘
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构。本文介绍了K-means算法的基本原理,包括初始化、数据点分配与簇中心更新等步骤,以及如何在Python中实现该算法,最后讨论了其优缺点及应用场景。
117 4
|
7天前
|
JSON 安全 API
淘宝商品详情API接口(item get pro接口概述)
淘宝商品详情API接口旨在帮助开发者获取淘宝商品的详细信息,包括商品标题、描述、价格、库存、销量、评价等。这些信息对于电商企业而言具有极高的价值,可用于商品信息展示、市场分析、价格比较等多种应用场景。
|
16天前
|
算法
PAI下面的gbdt、xgboost、ps-smart 算法如何优化?
设置gbdt 、xgboost等算法的样本和特征的采样率
39 2
|
17天前
|
数据采集 监控 数据挖掘
常用电商商品数据API接口(item get)概述,数据分析以及上货
电商商品数据API接口(item get)是电商平台上用于提供商品详细信息的接口。这些接口允许开发者或系统以编程方式获取商品的详细信息,包括但不限于商品的标题、价格、库存、图片、销量、规格参数、用户评价等。这些信息对于电商业务来说至关重要,是商品数据分析、价格监控、上货策略制定等工作的基础。
|
2月前
|
机器学习/深度学习 算法 数据挖掘
C语言在机器学习中的应用及其重要性。C语言以其高效性、灵活性和可移植性,适合开发高性能的机器学习算法,尤其在底层算法实现、嵌入式系统和高性能计算中表现突出
本文探讨了C语言在机器学习中的应用及其重要性。C语言以其高效性、灵活性和可移植性,适合开发高性能的机器学习算法,尤其在底层算法实现、嵌入式系统和高性能计算中表现突出。文章还介绍了C语言在知名机器学习库中的作用,以及与Python等语言结合使用的案例,展望了其未来发展的挑战与机遇。
51 1
|
2月前
|
机器学习/深度学习 自然语言处理 算法
深入理解机器学习算法:从线性回归到神经网络
深入理解机器学习算法:从线性回归到神经网络
|
2月前
|
机器学习/深度学习 算法
深入探索机器学习中的决策树算法
深入探索机器学习中的决策树算法
42 0
|
2月前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
105 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
3月前
|
机器学习/深度学习 人工智能 自然语言处理
【MM2024】阿里云 PAI 团队图像编辑算法论文入选 MM2024
阿里云人工智能平台 PAI 团队发表的图像编辑算法论文在 MM2024 上正式亮相发表。ACM MM(ACM国际多媒体会议)是国际多媒体领域的顶级会议,旨在为研究人员、工程师和行业专家提供一个交流平台,以展示在多媒体领域的最新研究成果、技术进展和应用案例。其主题涵盖了图像处理、视频分析、音频处理、社交媒体和多媒体系统等广泛领域。此次入选标志着阿里云人工智能平台 PAI 在图像编辑算法方面的研究获得了学术界的充分认可。
【MM2024】阿里云 PAI 团队图像编辑算法论文入选 MM2024
|
2月前
|
机器学习/深度学习 算法 Python
机器学习入门:理解并实现K-近邻算法
机器学习入门:理解并实现K-近邻算法
39 0