机器学习第14天:KNN近邻算法

简介: 机器学习第14天:KNN近邻算法



介绍

KNN算法的核心思想是:当我们要判断一个数据为哪一类时,我们找与它相近的一些数据,以这些数据的类别来判断新数据

实例

我们生成一些数据,看下面这张图

有两类点,红色与蓝色,这时我们再加入一个灰色的点

我们设置模型选择周围的三个点,可以看到最近的三个都是蓝色点,那么模型就会将新的数据判别为蓝色点


回归任务

尽管KNN算法主要用来做分类任务,但它也可以用来回归,新数据的值就是相近样本的平均值

缺点

由于它没有拟合参数,仅仅是找到周围样本点的平均值,在一些有趋势的曲线中它的预测往往不会很好

实例

我们创建几个样本点,可以看到这是一个完美的线性曲线,我们看看k近邻算法在这个简单任务上的表现

# 导入必要的库
from sklearn.neighbors import KNeighborsRegressor
 
# 生成一些示例数据(假设是二维特征)
X = [[1], [2], [3], [4], [5]]
y = [[3], [6], [9], [12], [15]]
 
x_new = [[6]]
 
# 创建 KNN 回归器,假设 K=3
knn = KNeighborsRegressor(n_neighbors=3)
 
# 在训练数据上拟合模型
knn.fit(X, y)
 
# 在测试数据上进行预测
y_pred = knn.predict(x_new)
 
print(y_pred)

在这个数据集上x为6的点y值应该是18,可是k近邻回归的特点取周围样本点的平均值,结果就会是12


分类任务

我们以上图的数据为例

# 导入KNN分类库
from sklearn.neighbors import KNeighborsClassifier
 
 
# 生成一些示例数据
X = [[1, 8], [2, 5], [3, 7], [5, 13], [6, 11], [7, 14]]
y = [0, 0, 0, 1, 1, 1]
 
x_new = [[6, 12]]
 
# 创建 KNN 分类器,设置k=3
knn = KNeighborsClassifier(n_neighbors=3)
 
# 在训练数据上拟合模型
knn.fit(X, y)
 
# 进行预测
y_pred = knn.predict(x_new)
 
print(y_pred)

n_neighbors参数设置了新数据要参考周围的多少个点,这里设置为3,代表参考相近的三个点的值

结果为1


如何选择最佳参数

由以上知识可以知道,影响KNN算法的参数是n_neighbors,那么我们可以更新n_neighbors,然后记录下每个参数模型在测试集上的损失来获得最优参数

绘制代码如下,这里主要学习思想,数据可能会在之后的机器学习实战系列中遇到

import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split, cross_val_score
import pandas as pd
import numpy as np
 
# 读取数据
data = pd.read_csv("datasets/data-science-london-scikit-learn/train.csv", header=None)
y = pd.read_csv("datasets/data-science-london-scikit-learn/trainLabels.csv", header=None)
y = np.ravel(y)
 
# 将数据分为训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(data, y, test_size=0.2, random_state=42)
 
N = range(2, 26)
kfold = 10
test_acc = []
val_acc = []
 
# 记录不同参数的准确率
for n in N:
    knn = KNeighborsClassifier(n_neighbors=n)
    knn.fit(x_train, y_train)
    test_acc.append(knn.score(x_train, y_train))
    val_acc.append(np.mean(cross_val_score(knn, x_test, y_test, cv=kfold)))
 
 
# 绘制准确率曲线
plt.plot(range(2, 26), test_acc, c='b', label='test_acc')
plt.plot(range(2, 26), val_acc, c='r', label='val_acc')
plt.xlabel('Number of Neighbors')
plt.ylabel('Accuracy')
plt.title('K Neighbors vs Accuracy')
plt.legend()
plt.show()
 

得到准确率与交叉验证误差曲线,

可以看到n_neighbors=5时模型的准确率最好,我们最后就可以使用这个参数


结语

  • k近邻算法几乎没有训练过程,它只需要记住训练集的特征就行,以便之后进行比较,它不需要拟合什么参数
  • 可以绘制准确率曲线来找到最好的k值
  • 可以进行回归任务,但在模型情况下效果不是很好

感谢阅读,觉得有用的话就订阅下本专栏吧

相关文章
|
3天前
|
机器学习/深度学习 算法 数据处理
探索机器学习中的决策树算法
【5月更文挑战第18天】探索机器学习中的决策树算法,一种基于树形结构的监督学习,常用于分类和回归。算法通过递归划分数据,选择最优特征以提高子集纯净度。优点包括直观、高效、健壮和可解释,但易过拟合、对连续数据处理不佳且不稳定。广泛应用于信贷风险评估、医疗诊断和商品推荐等领域。优化方法包括集成学习、特征工程、剪枝策略和参数调优。
|
4天前
|
机器学习/深度学习 算法 数据挖掘
【机器学习】K-means算法与PCA算法之间有什么联系?
【5月更文挑战第15天】【机器学习】K-means算法与PCA算法之间有什么联系?
|
4天前
|
机器学习/深度学习 算法 数据挖掘
【机器学习】维度灾难问题会如何影响K-means算法?
【5月更文挑战第15天】【机器学习】维度灾难问题会如何影响K-means算法?
|
5天前
|
机器学习/深度学习 算法 数据挖掘
【机器学习】聚类算法中,如何判断数据是否被“充分”地聚类,以便算法产生有意义的结果?
【5月更文挑战第14天】【机器学习】聚类算法中,如何判断数据是否被“充分”地聚类,以便算法产生有意义的结果?
|
5天前
|
机器学习/深度学习 运维 算法
【机器学习】可以利用K-means算法找到数据中的离群值吗?
【5月更文挑战第14天】【机器学习】可以利用K-means算法找到数据中的离群值吗?
|
6天前
|
算法 数据安全/隐私保护 计算机视觉
基于二维CS-SCHT变换和LABS方法的水印嵌入和提取算法matlab仿真
该内容包括一个算法的运行展示和详细步骤,使用了MATLAB2022a。算法涉及水印嵌入和提取,利用LAB色彩空间可能用于隐藏水印。水印通过二维CS-SCHT变换、低频系数处理和特定解码策略来提取。代码段展示了水印置乱、图像处理(如噪声、旋转、剪切等攻击)以及水印的逆置乱和提取过程。最后,计算并保存了比特率,用于评估水印的稳健性。
|
3天前
|
算法
m基于BP译码算法的LDPC编译码matlab误码率仿真,对比不同的码长
MATLAB 2022a仿真实现了LDPC码的性能分析,展示了不同码长对纠错能力的影响。短码长LDPC码收敛快但纠错能力有限,长码长则提供更强纠错能力但易陷入局部最优。核心代码通过循环进行误码率仿真,根据EsN0计算误比特率,并保存不同码长(12-768)的结果数据。
21 9
m基于BP译码算法的LDPC编译码matlab误码率仿真,对比不同的码长
|
4天前
|
算法
MATLAB|【免费】融合正余弦和柯西变异的麻雀优化算法SCSSA-CNN-BiLSTM双向长短期记忆网络预测模型
这段内容介绍了一个使用改进的麻雀搜索算法优化CNN-BiLSTM模型进行多输入单输出预测的程序。程序通过融合正余弦和柯西变异提升算法性能,主要优化学习率、正则化参数及BiLSTM的隐层神经元数量。它利用一段简单的风速数据进行演示,对比了改进算法与粒子群、灰狼算法的优化效果。代码包括数据导入、预处理和模型构建部分,并展示了优化前后的效果。建议使用高版本MATLAB运行。
|
6天前
|
算法 计算机视觉
基于高斯混合模型的视频背景提取和人员跟踪算法matlab仿真
该内容是关于使用MATLAB2013B实现基于高斯混合模型(GMM)的视频背景提取和人员跟踪算法。算法通过GMM建立背景模型,新帧与模型比较,提取前景并进行人员跟踪。文章附有程序代码示例,展示从读取视频到结果显示的流程。最后,结果保存在Result.mat文件中。
|
6天前
|
资源调度 算法 块存储
m基于遗传优化的LDPC码OMS译码算法最优偏移参数计算和误码率matlab仿真
MATLAB2022a仿真实现了遗传优化的LDPC码OSD译码算法,通过自动搜索最佳偏移参数ΔΔ以提升纠错性能。该算法结合了低密度奇偶校验码和有序统计译码理论,利用遗传算法进行全局优化,避免手动调整,提高译码效率。核心程序包括编码、调制、AWGN信道模拟及软输入软输出译码等步骤,通过仿真曲线展示了不同SNR下的误码率性能。
10 1

热门文章

最新文章