数据分析入门系列教程-KNN原理

简介: 数据分析入门系列教程-KNN原理

从今天开始,我们就进入正式的算法相关的学习了。在学习算法部分时,我希望你已经完全消化了前面所学习的内容,并能够熟练的掌握相关的知识了。


今天,我们来学习 KNN 算法。为什么要从 KNN 算法开始学习呢,因为这个算法是所有机器学习领域的算法中,是最简单,最易理解,最易实现的算法。我们从最简单的开始学习,慢慢锻炼相关的思维。


我们先来看下面的一个例子:

小 K 现在想应聘某家公司,很幸运,现在他的手里拿到了一些应聘者的数据和是否获得 offer 的情况。

image.png

可以看出,是否获得 offer,取决于这个人的工作经验以及当前的工资情况。而从小 K 所处的位置来看,还不是很好判断,他能否顺利获得 offer。

那么下面我们就通过 KNN 算法来帮助小 K 判断下,他距离 offer 到底有多远呢。


KNN工作原理


KNN 也就是 K 近邻算法,即某个待分类元素,查看它与其他类别元素的距离,统计距离最短的 K 个邻居,K 个中属于哪个类别的多,则该分类属于哪个类别。

下面,我们就通过上面的 offer 案例,来实际体验下 KNN 是如何工作的。

K = 1

当我们只取该元素的一个邻居作为参考时(K=1),就是查看与该元素距离最近的一个元素,它属于哪个类别,那么待分类的元素就属于哪个类别

image.png

如图中所示,与小 K 最近的邻居,是属于未获得 offfer 类别的,所以,在 K = 1 的情况下,小 K 是无法获得 offer 的。

K = 3

我们再来看看 K = 3 的情况,会不会发生变化呢

image.png

情况确实有所变化,当取3个近邻作为参考时,距离他最近的三个元素,有两个元素属于获得 offer 类别,有一个元素属于未获得 offer 类别,此时通过投票的机制,可以把小 K 划分到获得 offer 的一类中,小 K 又快乐的成为了可以获得 offer 的人。

K = 5

继续看当 k = 5 时,情况类似如下

image.png

在5个最近邻中,三个是获得 offer 的类别,两个是未获得 offer 的类别,同样可以投票选出,小 K 属于获得 offer 的类别。

以此类推,我们还可以查看 K = 7,9 等情况下,元素的分类状态。


KNN 计算过程

由上面的例子,我们可以得出 KNN 算法的基本步骤

  • 计算待分类元素与其他已分类元素的距离
  • 统计距离最近的 K 个邻居
  • 在 K 个邻居中,它们属于哪个分类多,待分类元素就属于哪个分类

看到了吧,整个 KNN 算法就是这样,是不是非常简单呢。不过你一定注意到了,随着 K 的取值不同,我们最终得到的结果也是不一样的,那么显而易见,K 的取值是十分重要的。


K值的选择


K 值很重要,那么 K 值选择多少合适呢?

首先,K 值一般都会选择为奇数,这个很好理解,如果是偶数,就有可能出现不同分类数量一致的情况,奇数的 K 值,就能够保证一定能选出最多的那个分类。

当然,这里针对的是二分类问题,K 设置为奇数可以有效的避免无法分类的情况,但是如果是多分类情况,还是无法有效的避免无法分类的情况发生。


那么对于 K 值的大小呢?

如果 K 值选择的比较小,那么如果它的邻居中存在噪声点,就会对元素的分类产生误差,也就是会有过拟合的风险,预测结果对于近邻的元素会非常敏感。

如果 K 值选择的比较大,即距离待分类元素很远的元素也能够影响分类,那么就很难把该元素真正的分类出来,从而导致预测错误,分类模糊,有欠拟合的风险。


假设 K 值为1

这将意味着未知的元素点的类别将由最近的1个已知样本点所决定。对于训练集,其误差几乎为0,但是在测试集当中,训练误差可能会非常大,因为最近的1个已知点可能是正常值,也可能是异常值。


假设 K 值为 N

这将意味着未知元素点的类别将由所有已知的样本点中频数最高的类别来决定。那么不论是训练集还是测试集,都会被判别为1种类别,这显然是不准确的,从而使得训练的模型无法正常识别未知样本的类别。


那么到底怎么选择 K 的取值呢?

当然有办法,业界一般会使用交叉验证(Cross Validation)的思维来选取 K 值。

何为交叉验证呢,就是把训练集进一步分成训练数据(Training Data)和验证数据(Validation Data),在训练数据上取不同的 K 值进行模型训练,然后在验证数据上做验证,最终选择在验证数据里最好的 K 值作为最终的 K 值。

image.png

现在,我们先把总样本数据分成训练集和测试集两部分,然后再把训练集分出一部分作为验证集。这样,在验证集中表现比较好的模型,就可以拿到测试集中做测试了。

到现在,我们已经知道了 KNN 的原理,还清楚了该如何选择 K 的取值,那么还有最后一个问题,待分类元素和已分类元素之间的距离该如何计算呢?


计算距离


两个样本点之间的距离代表了这两个样本之间的相似度。距离越大,差异性越大;距离越小,相似度越大。

而最常用的计算距离的方式,就是欧式距离。

如果空间中存在两个点 A(X1, Y1),B(X2, Y2),那么它们之间的直线距离为

image.png

image.png

如果将点的坐标扩展到 n 维空间

在空间中有 X 和 Y 两个点,其坐标分别为(X1,X2,X3…Xn)和(Y1,Y2,Y3…Yn),那么这两点之间的距离为:

image.png

如上,就是欧式距离的计算公式。

当然,距离的计算,还有很多其他的方式,比如曼哈顿距离,切比雪夫距离等等,这里就不再一一介绍了,感兴趣的同学可以自行查找学习。


手写KNN


既然 KNN 是最简单的算法,而且我们也明白了其工作原理,那么现在就尝试着动手来写一个 KNN 算法,其核心就是计算两点之间的距离。

在动手之前,我们先来学习一个 Python 的内置库 Counter

Counter 其实继承的是一个字典类,所以它是可以使用字典对应的方法的。该库主要的功能就是统计追踪值出现的次数。

from collections import Counter 
mytest = Counter("abcabcaba")
print(mytest)
>>>
Counter({'a': 4, 'b': 3, 'c': 2})

Counter 会统计出字符串中每个元素出现的次数

most_common 函数,用来选出出现次数最多的 n 个

print(mytest.most_common(2))
>>>
[('a', 4), ('b', 3)]

后面我们就可以使用该函数来选出 KNN 算法中最近的 K 个邻居

实现欧式距离

我们可以根据上面的欧式距离公式,两个点各维位置坐标相减的平方再开方。

import numpy as np
np.sqrt(sum((instance1 - instance2)**2))

NumPy 中的 sqrt 函数,就是计算非负实数的平方根函数。

再把上面的代码封装成函数

def cal_dis(instance1, instance2):
    dist = np.sqrt(sum((instance1 - instance2)**2))
    return dist

其中的 instance1 和 instance2 都代表空间中的点,都是 array 类型数据。

实现核心函数

下面就开始实现 KNN 算法的核心函数

首先计算需要测试的样本和已知样本所有数据的距离

distances = [euc_dis(x, testdata) for x in X]

X 就是已知样本的数据集,把从 X 种循环取出的 x 值与测试数据 testdata 计算距离,得到一个列表。

然后使用 NumPy 中的 argsort 函数进行排序,该函数会返回数据值从小到大的索引值,如:

test_data1 = np.array([2, 1, 5, 0])
print(np.argsort(test_data1))
>>>
[3 1 0 2]

可以看到,最小值是0,故其索引3排在第一位,接下来就是1(1),2(0)和5(2)

实现距离列表的从小到大排序

knbs = np.argsort(distances)[:k]

最后根据 K 的取值,截取前 K 个最小值。

接着再使用 Counter 函数,查看每一类出现的次数

count = Counter(y[knbs])
count.most_common()[0][0]

注意 knbs 是下标,故而 y[knbs] 才是样本值,然后使用 most_common 来提取出对应的分类。


最后封装成函数

def my_knn(X, y, testdata, k):
    distances = [cal_dis(x, testdata) for x in X]
    knbs= np.argsort(distances)[:k]
    count = Counter(y[knbs])
    return count.most_common()[0][0]


验证KNN


在著名的机器学习模块 Scikit-learn 当中也是包含 KNN 算法模型的,下面我们就通过两个模型来比较下,看看我们自己实现的 KNN 算法模型的表现怎么样。

pip3 install scikit-learn  # 安装


这里直接使用 sklearn 自带的 iris 数据集来做一个简单的比较

from sklearn import datasets  # 导入数据集
from collections import Counter
from sklearn.model_selection import train_test_split
import numpy as np


导入 iris 数据

# 导入iris数据
iris = datasets.load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=2002)

train_test_split 函数是一个切割训练集和测试集的函数,可以把整体数据集分割成一定比例的两部分,默认训练集和测试集的比例为3:1。


使用自写 KNN 代码做分类预测

predictions = [my_knn(X_train, y_train, data, 3) for data in X_test]
print("自写 KNN 准确率", accuracy_score(y_test, predictions))
>>>
自写 KNN 准确率 0.9210526315789473

accuracy_score 函数用来计算预测值与实际值之间的误差。


使用 sklearn 自带的 KNN 分类器做预测

knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_train, y_train)
predict_y = knn.predict(X_test)
print("KNN 准确率", accuracy_score(y_test, predict_y))
>>>
KNN 准确率 0.9210526315789473

可以看到,在这个数据集上,两种 KNN 分类器模型的表现是一样的,说明我们自行手写的简易版 KNN 算法还是可以的。

当然,sklearn 自带的 KNN 分类器还有很多重要的参数,性能方面也是做过调优的,具体内容我们下节在实战中再详细阐述。


KNN算法的优势


优点

  • KNN 算法简单易懂,精度高,理论成熟,既可以做分类也可以做回归
  • 既可以应用在数值型数据上,也可以应用在离散型数据上
  • 当 K 值选择的比较合理时,该算法对异常值会变的不敏感

缺点

  • 当样本数据过大时,计算量也随之增大
  • 在样本不平衡的数据集上表现不是很好(有些分类数据量大,有些分类数据量非常小)
  • 对于数据间的内在联系,是无法给出的


总结


今天我讲解了机器学习算法中最简单最易懂的 Hello World 级别的算法 KNN,在几乎不需要任何数学知识的情况下,我们就可以手动完成一个 KNN 算法的编写。


KNN 算法有几点是非常关键的,比如 K 值的选择,通常通过交叉验证的方式来选择。又比如最近邻距离的计算,最常用的计算距离的方式就是欧式距离。当然不同的场景,使用的距离计算方式也不尽相同。


KNN 算法虽然理论简单,易于实现,但是其缺点也很明显。当数据量过大时,计算量也是非常庞大的,需要大量的存储空间和计算时间。此外对于不平衡的样本,KNN 算法的准确率也会大大降低。

image.png

想一想,应该怎样通过交叉验证的方式来选择 iris 数据集例子中的 K 值?

相关文章
|
1月前
|
机器学习/深度学习 数据可视化 数据挖掘
使用Python进行数据分析的入门指南
本文将引导读者了解如何使用Python进行数据分析,从安装必要的库到执行基础的数据操作和可视化。通过本文的学习,你将能够开始自己的数据分析之旅,并掌握如何利用Python来揭示数据背后的故事。
|
2月前
|
机器学习/深度学习 数据可视化 数据挖掘
使用Python进行数据分析的入门指南
【10月更文挑战第42天】本文是一篇技术性文章,旨在为初学者提供一份关于如何使用Python进行数据分析的入门指南。我们将从安装必要的工具开始,然后逐步介绍如何导入数据、处理数据、进行数据可视化以及建立预测模型。本文的目标是帮助读者理解数据分析的基本步骤和方法,并通过实际的代码示例来加深理解。
71 3
|
2月前
|
数据可视化 数据挖掘
R中单细胞RNA-seq数据分析教程 (3)
R中单细胞RNA-seq数据分析教程 (3)
40 3
R中单细胞RNA-seq数据分析教程 (3)
|
2月前
|
SQL 数据挖掘 Python
R中单细胞RNA-seq数据分析教程 (1)
R中单细胞RNA-seq数据分析教程 (1)
42 5
R中单细胞RNA-seq数据分析教程 (1)
|
2月前
|
机器学习/深度学习 数据挖掘
R中单细胞RNA-seq数据分析教程 (2)
R中单细胞RNA-seq数据分析教程 (2)
53 0
R中单细胞RNA-seq数据分析教程 (2)
|
2月前
|
数据采集 数据可视化 数据挖掘
深入浅出:使用Python进行数据分析的基础教程
【10月更文挑战第41天】本文旨在为初学者提供一个关于如何使用Python语言进行数据分析的入门指南。我们将通过实际案例,了解数据处理的基本步骤,包括数据的导入、清洗、处理、分析和可视化。文章将用浅显易懂的语言,带领读者一步步掌握数据分析师的基本功,并在文末附上完整的代码示例供参考和实践。
|
3月前
|
数据采集 机器学习/深度学习 数据可视化
深入浅出:用Python进行数据分析的入门指南
【10月更文挑战第21天】 在信息爆炸的时代,掌握数据分析技能就像拥有一把钥匙,能够解锁隐藏在庞大数据集背后的秘密。本文将引导你通过Python语言,学习如何从零开始进行数据分析。我们将一起探索数据的收集、处理、分析和可视化等步骤,并最终学会如何利用数据讲故事。无论你是编程新手还是希望提升数据分析能力的专业人士,这篇文章都将为你提供一条清晰的学习路径。
|
3月前
|
数据挖掘 索引 Python
Python数据分析篇--NumPy--入门
Python数据分析篇--NumPy--入门
52 0
|
5月前
|
数据采集 数据可视化 数据挖掘
数据分析大神养成记:Python+Pandas+Matplotlib助你飞跃!
在数字化时代,数据分析至关重要,而Python凭借其强大的数据处理能力和丰富的库支持,已成为该领域的首选工具。Python作为基石,提供简洁语法和全面功能,适用于从数据预处理到高级分析的各种任务。Pandas库则像是神兵利器,其DataFrame结构让表格型数据的处理变得简单高效,支持数据的增删改查及复杂变换。配合Matplotlib这一数据可视化的魔法棒,能以直观图表展现数据分析结果。掌握这三大神器,你也能成为数据分析领域的高手!
97 2
|
2月前
|
机器学习/深度学习 算法 数据挖掘
数据分析的 10 个最佳 Python 库
数据分析的 10 个最佳 Python 库
102 4
数据分析的 10 个最佳 Python 库