机器学习第14天:KNN近邻算法

简介: 机器学习第14天:KNN近邻算法



介绍

KNN算法的核心思想是:当我们要判断一个数据为哪一类时,我们找与它相近的一些数据,以这些数据的类别来判断新数据

实例

我们生成一些数据,看下面这张图

有两类点,红色与蓝色,这时我们再加入一个灰色的点

我们设置模型选择周围的三个点,可以看到最近的三个都是蓝色点,那么模型就会将新的数据判别为蓝色点


回归任务

尽管KNN算法主要用来做分类任务,但它也可以用来回归,新数据的值就是相近样本的平均值

缺点

由于它没有拟合参数,仅仅是找到周围样本点的平均值,在一些有趋势的曲线中它的预测往往不会很好

实例

我们创建几个样本点,可以看到这是一个完美的线性曲线,我们看看k近邻算法在这个简单任务上的表现

# 导入必要的库
from sklearn.neighbors import KNeighborsRegressor
 
# 生成一些示例数据(假设是二维特征)
X = [[1], [2], [3], [4], [5]]
y = [[3], [6], [9], [12], [15]]
 
x_new = [[6]]
 
# 创建 KNN 回归器,假设 K=3
knn = KNeighborsRegressor(n_neighbors=3)
 
# 在训练数据上拟合模型
knn.fit(X, y)
 
# 在测试数据上进行预测
y_pred = knn.predict(x_new)
 
print(y_pred)

在这个数据集上x为6的点y值应该是18,可是k近邻回归的特点取周围样本点的平均值,结果就会是12


分类任务

我们以上图的数据为例

# 导入KNN分类库
from sklearn.neighbors import KNeighborsClassifier
 
 
# 生成一些示例数据
X = [[1, 8], [2, 5], [3, 7], [5, 13], [6, 11], [7, 14]]
y = [0, 0, 0, 1, 1, 1]
 
x_new = [[6, 12]]
 
# 创建 KNN 分类器,设置k=3
knn = KNeighborsClassifier(n_neighbors=3)
 
# 在训练数据上拟合模型
knn.fit(X, y)
 
# 进行预测
y_pred = knn.predict(x_new)
 
print(y_pred)

n_neighbors参数设置了新数据要参考周围的多少个点,这里设置为3,代表参考相近的三个点的值

结果为1


如何选择最佳参数

由以上知识可以知道,影响KNN算法的参数是n_neighbors,那么我们可以更新n_neighbors,然后记录下每个参数模型在测试集上的损失来获得最优参数

绘制代码如下,这里主要学习思想,数据可能会在之后的机器学习实战系列中遇到

import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split, cross_val_score
import pandas as pd
import numpy as np
 
# 读取数据
data = pd.read_csv("datasets/data-science-london-scikit-learn/train.csv", header=None)
y = pd.read_csv("datasets/data-science-london-scikit-learn/trainLabels.csv", header=None)
y = np.ravel(y)
 
# 将数据分为训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(data, y, test_size=0.2, random_state=42)
 
N = range(2, 26)
kfold = 10
test_acc = []
val_acc = []
 
# 记录不同参数的准确率
for n in N:
    knn = KNeighborsClassifier(n_neighbors=n)
    knn.fit(x_train, y_train)
    test_acc.append(knn.score(x_train, y_train))
    val_acc.append(np.mean(cross_val_score(knn, x_test, y_test, cv=kfold)))
 
 
# 绘制准确率曲线
plt.plot(range(2, 26), test_acc, c='b', label='test_acc')
plt.plot(range(2, 26), val_acc, c='r', label='val_acc')
plt.xlabel('Number of Neighbors')
plt.ylabel('Accuracy')
plt.title('K Neighbors vs Accuracy')
plt.legend()
plt.show()
 

得到准确率与交叉验证误差曲线,

可以看到n_neighbors=5时模型的准确率最好,我们最后就可以使用这个参数


结语

  • k近邻算法几乎没有训练过程,它只需要记住训练集的特征就行,以便之后进行比较,它不需要拟合什么参数
  • 可以绘制准确率曲线来找到最好的k值
  • 可以进行回归任务,但在模型情况下效果不是很好

感谢阅读,觉得有用的话就订阅下本专栏吧

相关文章
|
21天前
|
机器学习/深度学习 算法 数据挖掘
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构。本文介绍了K-means算法的基本原理,包括初始化、数据点分配与簇中心更新等步骤,以及如何在Python中实现该算法,最后讨论了其优缺点及应用场景。
65 4
|
17天前
|
机器学习/深度学习 算法 数据挖掘
C语言在机器学习中的应用及其重要性。C语言以其高效性、灵活性和可移植性,适合开发高性能的机器学习算法,尤其在底层算法实现、嵌入式系统和高性能计算中表现突出
本文探讨了C语言在机器学习中的应用及其重要性。C语言以其高效性、灵活性和可移植性,适合开发高性能的机器学习算法,尤其在底层算法实现、嵌入式系统和高性能计算中表现突出。文章还介绍了C语言在知名机器学习库中的作用,以及与Python等语言结合使用的案例,展望了其未来发展的挑战与机遇。
37 1
|
26天前
|
机器学习/深度学习 自然语言处理 算法
深入理解机器学习算法:从线性回归到神经网络
深入理解机器学习算法:从线性回归到神经网络
|
1月前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
77 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
26天前
|
机器学习/深度学习 算法
深入探索机器学习中的决策树算法
深入探索机器学习中的决策树算法
34 0
|
27天前
|
机器学习/深度学习 算法 Python
机器学习入门:理解并实现K-近邻算法
机器学习入门:理解并实现K-近邻算法
32 0
|
12天前
|
算法
基于WOA算法的SVDD参数寻优matlab仿真
该程序利用鲸鱼优化算法(WOA)对支持向量数据描述(SVDD)模型的参数进行优化,以提高数据分类的准确性。通过MATLAB2022A实现,展示了不同信噪比(SNR)下模型的分类误差。WOA通过模拟鲸鱼捕食行为,动态调整SVDD参数,如惩罚因子C和核函数参数γ,以寻找最优参数组合,增强模型的鲁棒性和泛化能力。
|
18天前
|
机器学习/深度学习 算法 Serverless
基于WOA-SVM的乳腺癌数据分类识别算法matlab仿真,对比BP神经网络和SVM
本项目利用鲸鱼优化算法(WOA)优化支持向量机(SVM)参数,针对乳腺癌早期诊断问题,通过MATLAB 2022a实现。核心代码包括参数初始化、目标函数计算、位置更新等步骤,并附有详细中文注释及操作视频。实验结果显示,WOA-SVM在提高分类精度和泛化能力方面表现出色,为乳腺癌的早期诊断提供了有效的技术支持。
|
6天前
|
存储 算法
基于HMM隐马尔可夫模型的金融数据预测算法matlab仿真
本项目基于HMM模型实现金融数据预测,包括模型训练与预测两部分。在MATLAB2022A上运行,通过计算状态转移和观测概率预测未来值,并绘制了预测值、真实值及预测误差的对比图。HMM模型适用于金融市场的时间序列分析,能够有效捕捉隐藏状态及其转换规律,为金融预测提供有力工具。
|
6天前
|
机器学习/深度学习 算法 信息无障碍
基于GoogleNet深度学习网络的手语识别算法matlab仿真
本项目展示了基于GoogleNet的深度学习手语识别算法,使用Matlab2022a实现。通过卷积神经网络(CNN)识别手语手势,如"How are you"、"I am fine"、"I love you"等。核心在于Inception模块,通过多尺度处理和1x1卷积减少计算量,提高效率。项目附带完整代码及操作视频。