Python实现KNN算法和交叉验证

简介: Python实现KNN算法和交叉验证

KNN基础知识


KNN(K-Nearest Neighbors)算法原理


       “近朱者赤,近墨者黑”——从训练数据集中找出和待预测样本 最接近的K个样本,然后 投票决定待预测样本的分类;如果是回归问题,则求出K个样本的平 均值作为待预测样本最 终的预测值

样本距离公式



特征标准化问题


如果样本的多个特征值差别很大,或者样本特征的量纲不一致, 导致样本间距离被某些 特征所主导,就应该考虑样本特征标准化的问题


最常用的特征标准化方法是:z-score标准化


z-score标准化通过sklearn中的 sklearn.preprocessing.StandardScaler实现


实战——使用KNN完成鸢尾花分类预测

在sklearn中使用sklearn.neighbors.KNeighborsClassifier实现 KNN分类的功能

from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
# 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris.data # 所有样本特征
y = iris.target  # 所有样本标签
# 将数据集拆分成训练样本集和测试样本集
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.2,random_state=666)
# 样本特征标准化
std = StandardScaler()
# 通过训练样本集特征进行标准化拟合并转换
X_train_standard = std.fit_transform(X_train)  
# 对于测试集,直接转换即可!
X_test_standard = std.transform(X_test)
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=3)  # 创建KNN分类器对象,K=3
knn.fit(X_train_standard,y_train)  # 拟合
# 根据样本特征,对测试样本进行预测
y_predict = knn.predict(X_test_standard)
# 直接调用score方法,得出分类准确率
print(knn.score(X_test_standard,y_test))


准确率100%,说明这原始自带数据集特征很明显


交叉验证


什么是交叉验证(Cross Validation)


       交叉验证是一种模型选择方法和调参方法,它随机地将数据集 切分成三部分,分别为训 练集(training set)、验证集(validation set)和测试集(test set)。训练 集用来训练模型,验证 集用于模型的选择,测试集用于最终对学习方法的评估。


K折交叉验证(k-fold cross validation)


首先随机地将已给训练数据集切分为k个互不相交的大小相同的 子集;然后利用K-1个子 集的数据训练模型,利用余下的子集验证模型;将这一过程对可能 的K种选择重复进行(这 一过程使用的是同一组超参数);最后通过计算K次的预测误差,对 其平均便会得到1个交 叉验证误差(也就是这一组超参数的预测误差或成绩)。



留一交叉验证(leave-one-out cross validation)


留一交叉验证(留一法)是K折交叉验证的特殊情形,即: K=N,这里N是给定训练数 据集的容量。 留一法不受随机样本划分方式的影响,最接近模型真正的性能 指标。因为N个样本只有 唯一的方式划分为N个子集——每个子集包含一个样本。 缺点:计算量巨大。


实战——手写数字图片数据集的调 参、分类识别

交叉验证使用sklearn.model_selection.cross_val_score来完成

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# 数据加载,展示图像
digits = datasets.load_digits()
images = digits.images   # 所有图像数据
plt.gray()  # 灰色图像
plt.matshow(images[0])  # 显示第一个图像
plt.show()


可以看出这是一个数字0


X = digits.data  # 样本特征
y = digits.target   # 样本标签
# 拆分数据集
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.4,random_state=666)
# 交叉验证开始
from sklearn.model_selection import cross_val_score
best_k,best_p,best_score = 0,0,0
for k in range(2,11):  # 外层循环搜索k
    for p in range(1,6):  # 内层循环搜索p
        knn = KNeighborsClassifier(weights="distance",n_neighbors=k,p=p)
        scores = cross_val_score(knn,X_train,y_train,cv=3)  # 3折交叉验证
        score = np.mean(scores)  # 当前这一组超参数在验证集上的平均分
        if score > best_score:
            best_k,best_p,best_score = k,p,score
print("best_k=",best_k)
print("best_p=",best_p)
print("验证最好成绩:best_score=",best_score)

# 使用调好的超参数进行训练与测试
best_knn = KNeighborsClassifier(weights="distance",n_neighbors=2,p=2)
best_knn.fit(X_train,y_train)
best_knn.score(X_test,y_test)  # 测试集上最终的分数

实战——使用网格搜索进行调参

什么是网格搜索


网格搜索可以实现自动调参并返回最佳的参数组合

网格搜索,搜索的是参数,即在指定的参数范围内,依次调整参 数,利用调整的参数训练学习器

网格搜索的sklearn实现


使用sklearn.model_selection.GridSearchCV实现网格搜索 GridSearchCV的名字可以拆分为两部分,GridSearch和CV,即 网格搜索和交叉验证  

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# 加载数据集
digits = datasets.load_digits()
X = digits.data  # 样本特征
y = digits.target  # 样本标签
# 拆分数据集
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.4,random_state=666)
# 网格搜索
from sklearn.model_selection import GridSearchCV
# 组装待搜索的超参数组合
param_grid = [
    {
        "weights":["uniform"],
        "n_neighbors":[i for i in range(1,11)]
    },
    {
        "weights":["distance"],
        "n_neighbors":[i for i in range(1,11)],
        "p":[i for i in range(1,6)]
    }
]
knn = KNeighborsClassifier()
gs = GridSearchCV(knn,param_grid,cv=3,n_jobs=-1)
gs.fit(X_train,y_train)  # 搜索最佳超参数组合(很耗时!)
print(gs.best_params_)  # 最佳超参数组合

print(gs.best_score_)  # 最佳验证成绩

# 携带最佳超参数组合的KNeighborsClassifier对象
best_knn = gs.best_estimator_
best_knn.fit(X_train,y_train)  # 使用最佳超参数组合的分类器进行拟合训练
print("在测试集上的最终评估效果:",best_knn.score(X_test,y_test))




目录
相关文章
|
18天前
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
201 55
|
7天前
|
存储 缓存 监控
局域网屏幕监控系统中的Python数据结构与算法实现
局域网屏幕监控系统用于实时捕获和监控局域网内多台设备的屏幕内容。本文介绍了一种基于Python双端队列(Deque)实现的滑动窗口数据缓存机制,以处理连续的屏幕帧数据流。通过固定长度的窗口,高效增删数据,确保低延迟显示和存储。该算法适用于数据压缩、异常检测等场景,保证系统在高负载下稳定运行。 本文转载自:https://www.vipshare.com
101 66
|
2月前
|
搜索推荐 Python
利用Python内置函数实现的冒泡排序算法
在上述代码中,`bubble_sort` 函数接受一个列表 `arr` 作为输入。通过两层循环,外层循环控制排序的轮数,内层循环用于比较相邻的元素并进行交换。如果前一个元素大于后一个元素,就将它们交换位置。
138 67
|
2月前
|
存储 搜索推荐 Python
用 Python 实现快速排序算法。
快速排序的平均时间复杂度为$O(nlogn)$,空间复杂度为$O(logn)$。它在大多数情况下表现良好,但在某些特殊情况下可能会退化为最坏情况,时间复杂度为$O(n^2)$。你可以根据实际需求对代码进行调整和修改,或者尝试使用其他优化策略来提高快速排序的性能
128 61
|
28天前
|
机器学习/深度学习 人工智能 算法
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
宠物识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了37种常见的猫狗宠物种类数据集【'阿比西尼亚猫(Abyssinian)', '孟加拉猫(Bengal)', '暹罗猫(Birman)', '孟买猫(Bombay)', '英国短毛猫(British Shorthair)', '埃及猫(Egyptian Mau)', '缅因猫(Maine Coon)', '波斯猫(Persian)', '布偶猫(Ragdoll)', '俄罗斯蓝猫(Russian Blue)', '暹罗猫(Siamese)', '斯芬克斯猫(Sphynx)', '美国斗牛犬
153 29
【宠物识别系统】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+图像识别
|
11天前
|
存储 运维 监控
探索局域网电脑监控软件:Python算法与数据结构的巧妙结合
在数字化时代,局域网电脑监控软件成为企业管理和IT运维的重要工具,确保数据安全和网络稳定。本文探讨其背后的关键技术——Python中的算法与数据结构,如字典用于高效存储设备信息,以及数据收集、异常检测和聚合算法提升监控效率。通过Python代码示例,展示了如何实现基本监控功能,帮助读者理解其工作原理并激发技术兴趣。
47 20
|
3天前
|
算法 网络协议 Python
探秘Win11共享文件夹之Python网络通信算法实现
本文探讨了Win11共享文件夹背后的网络通信算法,重点介绍基于TCP的文件传输机制,并提供Python代码示例。Win11共享文件夹利用SMB协议实现局域网内的文件共享,通过TCP协议确保文件传输的完整性和可靠性。服务器端监听客户端连接请求,接收文件请求并分块发送文件内容;客户端则连接服务器、接收数据并保存为本地文件。文中通过Python代码详细展示了这一过程,帮助读者理解并优化文件共享系统。
|
8天前
|
存储 算法 Python
文件管理系统中基于 Python 语言的二叉树查找算法探秘
在数字化时代,文件管理系统至关重要。本文探讨了二叉树查找算法在文件管理中的应用,并通过Python代码展示了其实现过程。二叉树是一种非线性数据结构,每个节点最多有两个子节点。通过文件名的字典序构建和查找二叉树,能高效地管理和检索文件。相较于顺序查找,二叉树查找每次比较可排除一半子树,极大提升了查找效率,尤其适用于海量文件管理。Python代码示例包括定义节点类、插入和查找函数,展示了如何快速定位目标文件。二叉树查找算法为文件管理系统的优化提供了有效途径。
40 5
|
2月前
|
数据采集 存储 算法
Python 中的数据结构和算法优化策略
Python中的数据结构和算法如何进行优化?
|
8天前
|
存储 缓存 算法
探索企业文件管理软件:Python中的哈希表算法应用
企业文件管理软件依赖哈希表实现高效的数据管理和安全保障。哈希表通过键值映射,提供平均O(1)时间复杂度的快速访问,适用于海量文件处理。在Python中,字典类型基于哈希表实现,可用于管理文件元数据、缓存机制、版本控制及快速搜索等功能,极大提升工作效率和数据安全性。
42 0