kNN算法——帮你找到身边最相近的人-阿里云开发者社区

开发者社区> 人工智能> 正文

kNN算法——帮你找到身边最相近的人

简介: 本文简单介绍最近邻算法的基本思想以及具体python实现,并且分析了其优缺点及适用范围,适合初学者理解与动手实践。

       新生开学了,部分大学按照兴趣分配室友的新闻占据了头条,这其中涉及到机器学习算法的应用。此外,新生进入大学后,可能至少参加几个学生组织或社团。社团是根据学生的兴趣将它们分为不同的类别,那么如何定义这些类别,或者区分各个组织之间的差别呢?我敢肯定,如果你问过运营这些社团的人,他们肯定不会说他们的社团和其它的社团相同,但在某种程度上是相似的。比如,老乡会和高中同学会都有着同样的生活方式;足球俱乐部和羽毛球协会对运动有着相同的兴趣;科技创新协会和创业俱乐部有相近的的兴趣等。也许让你去衡量这些社团或组织所处理的事情或运行模式,你自己就可以确定哪些社团是自己感兴趣的。但有一种算法能够帮助你更好地做出决策,那就是k-Nearest Neighbors(NN)算法, 本文将使用学生社团来解释k-NN算法的一些概念,该算法可以说是最简单的机器学习算法,构建的模型仅包含存储的训练数据集。该算法对新数据点进行预测,就是在训练数据集中找到最接近的数据点——其“最近邻居”。

0_jpeg

工作原理

       在其最简单的版本中,k-NN算法仅考虑一个最近邻居,这个最近邻居就是我们想要预测点的最近训练数据点。然后,预测结果就是该训练点的输出。下图说明构造的数据集分类情况。

1

       从图中可以看到,我们添加了三个新的数据点,用星星表示。对于三个点中的每一点,我们都标记了训练集中离其最近的点,最近邻算法的预测输出就是标记的这点(用交叉颜色进行表示)。
       同样,我们也可以考虑任意数量k个邻居,而不是只考虑一个最近的邻居。这是k-NN算法名称的由来。在考虑多个邻居时,我们使用投票的方式来分配标签。这意味着对于每个测试点,我们计算有多少个邻居属于0类以及有多少个邻居属于1类。然后我们统计这些近邻中属于哪一类占的比重大就将预测点判定为哪一类:换句话说,少数服从多数。以下示例使用了5个最近的邻居:

2


       同样,将预测结果用交叉的颜色表示。从图中可以看到,左上角的新数据点的预测与我们仅使用一个最近邻居时的预测结果不相同。
       虽然此图仅展示了用于二分类的问题,但此方法可应用于具有任意数量类的数据集。对于多分类问题,同样计算k个邻居属于哪些类,并进行数量统计,从中选取数量最多的类作为预测结果。

Scratch实现k-NN算法

以下是k-NN算法的伪代码,用于对一个数据点进行分类(将其称为A点):
对于数据集中的每一个点:

  • 首先,计算A点和当前点之间的距离;
  • 然后,按递增顺序对距离进行排序;
  • 其次,把距离最近的k个点作为A的最近邻;
  • 之后,找到这些邻居中的绝大多数类;
  • 最后,将绝大多数类返回作为我们对A类的预测;

Python实现代码如下:

def knnclassify(A, dataset, labels, k):
  datasetSize = dataset.shape[0]
  
  # 计算A点和当前点之间的距离
  diffMat = tile(A, (datasetSize, 1)) - dataset
  sqDiffMat = diffMat ** 2
  sqDistances = sqDiffMat.sum(axis=1)
  distances = sqDistances ** 0.5
  
  # 按照增序对距离排序
  sortedDistIndices = distances.argsort()
  
  # 选出距离最小的k个点
  classCount = {}
  for i in range(k):
    voteIlabel = labels[sortedDistIndices[i]]
    classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    
  # 对这些点所处的类别按照频次排序
  sortedClassCount = sorted(classCount.iteritem(), key=operator.itemgetter(1), reverse=True)
  
  return sortedClassCount[0][0]

下面让我们深入研究下上述代码:

  • 函数knnclassify需要4个输入参数:要分类的输入向量称为A,称为dataSet的训练样例的完整矩阵,称为labels的标签向量,以及k——在投票中使用的最近邻居的数量。
  • 使用欧几里德距离计算A和当前点之间的距离。
  • 按照递增顺序对距离进行排序。
  • 从中选出k个最近距离来对A类进行投票。
  • 之后,获取classCount字典并将其分解为元组列表,然后按元组中的第2项对元组进行排序。由于排序的顺序是相反的,因此我们选择从最大到最小(设置reverse)。
  • 最后,返回最频繁出现的类别标签。

Scikit-Learn实现k-NN算法

       Scikit-Learn是一个机器学习工具箱,内部集成了很多机器学习算法。现在让我们看一下如何使用Scikit-learn实现kNN算法。代码如下:

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# 导入iris数据集
iris = datasets.load_iris()
X = iris.data
y = iris.target

# 将其按照一定的比例划分为训练集和测试集(random_state=0 保证每次运行分割得到一样的训练集和测试集)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# 设定邻居个数 
clf = KNeighborsClassifier(n_neighbors=5)

# 拟合训练数据 
clf.fit(X_train, y_train)

# 对测试集进行预测 
predictions = clf.predict(X_test)
print("Test set predictions: {}".format(predictions))

# 评估模型性能
accuracy = clf.score(X_test, y_test)
print("Test set accuracy: {:.2f}".format(accuracy))

下面让我们来看看上述代码:

  • 首先,生成鸢尾属植物数据集;
  • 然后,将数据拆分为训练和测试集,以评估泛化性能;
  • 之后,将邻居数量(k)指定为5;
  • 接下来,使用训练集来拟合分类器;
  • 为了对测试数据进行预测,对于测试集中的每个数据点,都要使用该方法计算训练集中的最近邻居,并找到其中最频繁出现的类;
  • 最后,通过使用测试数据和测试标签调用score函数来评估模型的泛化能力;

       模型运行完毕,测试集上得到97%的准确度,这意味着模型在测试数据集中97%的样本都正确地预测出类别;

3

优点和缺点

       一般而言,k-NN分类器有两个重要参数:邻居数量以及数据点之间的距离计算方式。

  • 在实践应用中,一般使用少数3个或5个邻居时效果通常会很好。当然,应该根据具体情况调整这个参数;
  • 选择正确的距离测量方法可能有些困难。一般情况下,都是使用欧几里德距离,欧几里得距离在许多设置中效果都不错;

       k-NN的优势之一是该模型非常易于理解,并且通常无需进行大量参数调整的情况下就能获得比较不错的性能表现。在考虑使用更高级的技术之前,使用此算法是一种很好的基线方法。k-NN模型的建立通常会比较快,但是当训练集非常大时(无论是特征数还是样本数量),预测时耗费的时间会很多。此外,使用k-NN算法时,对数据进行预处理非常重要。该方法通常在具有许多特征(数百或更多)的数据集上表现不佳,并且对于大多数特征在大多数情况下为0的数据集(所谓的稀疏数据集)而言尤其糟糕。

结论

       k-NN算法是一种简单有效的数据分类方法,它是基于实例学习的一种机器学习算法,需要通过数据实例来执行机器学习算法,该算法必须携带完整的数据集。而对于大型的数据集,需要耗费比较大的存储。此外,还需要计算数据库中每个数据点距离预测点的的距离,这个过程会很麻烦,且耗时多。另一个缺点是k-NN算法不能够让你了解数据的基础结构,无法知道每个类别的“平均”或“范例”具体是什么样子。
       因此,虽然k-NN算法易于理解,但由于预测速度慢且无法处理多特征问题,因此在实践中并不常用。

参考资料

数十款阿里云产品限时折扣中,赶紧点击领劵开始云上实践吧!

作者信息

James Le,机器学习工程师
LinkedIn:http://www.linkedin.com/in/khanhnamle94
本文由阿里云云栖社区组织翻译。
文章原标题《k-Nearest Neighbors: Who are close to you》,译者:海棠,审校:Uncle_LLD。
文章为简译,更为详细的内容,请查看原文

版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。

分享:
人工智能
使用钉钉扫一扫加入圈子
+ 订阅

了解行业+人工智能最先进的技术和实践,参与行业+人工智能实践项目

其他文章