【Python机器学习】实验07 KNN最近邻算法1

简介: 【Python机器学习】实验07 KNN最近邻算法1

KNN算法

1.k kk近邻法是基本且简单的分类与回归方法。k kk近邻法的基本做法是:对给定的训练实例点和输入实例点,首先确定输入实例点的k kk个最近邻训练实例点,然后利用这k kk个训练实例点的类的多数来预测输入实例点的类。


2.k kk近邻模型对应于基于训练数据集对特征空间的一个划分。k kk近邻法中,当训练集、距离度量、k kk值及分类决策规则确定后,其结果唯一确定,没有近似,他没有学习参数。


3.k kk近邻法三要素:距离度量、k kk值的选择和分类决策规则。常用的距离度量是欧氏距离及更一般的pL距离。k kk值小时,k kk近邻模型更复杂;k kk值大时,k kk近邻模型更简单。k kk值的选择反映了对近似误差与估计误差之间的权衡,通常由交叉验证选择最优的k kk。


常用的分类决策规则是多数表决,对应于经验风险最小化。


4.k kk近邻法的实现需要考虑如何快速搜索k个最近邻点。kd树是一种便于对k维空间中的数据进行快速检索的数据结构。kd树是二叉树,表示对k kk维空间的一个划分,其每个结点对应于k kk维空间划分中的一个超矩形区域。利用kd树可以省去对大部分数据点的搜索, 从而减少搜索的计算量。

前言 距离度量

在机器学习算法中,我们经常需要计算样本之间的相似度,通常的做法是计算样本之间的距离。


设x xx和y yy为两个向量,求它们之间的距离。


这里用Numpy实现,设和为ndarray <numpy.ndarray>,它们的shape都是(N,)


d dd为所求的距离,是个浮点数(float)。

1) 欧式距离

欧几里得度量(euclidean metric)(也称欧氏距离)是一个通常采用的距离定义,指在m维空间中两个点之间的真实距离,或者向量的自然长度(即该点到原点的距离)。在二维和三维空间中的欧氏距离就是两点之间的实际距离。


距离公式:

f09cb31da7ef8cce6c0de022c66a909.png

代码实现:

def euclidean(x, y):
    return np.sqrt(np.sum((x - y)**2))


(2) 曼哈顿距离(Manhattan distance)

想象你在城市道路里,要从一个十字路口开车到另外一个十字路口,驾驶距离是两点间的直线距离吗?显然不是,除非你能穿越大楼。实际驾驶距离就是这个“曼哈顿距离”。而这也是曼哈顿距离名称的来源,曼哈顿距离也称为城市街区距离(City Block distance)。


距离公式:

29a8a308b688cb5fc4000ad244ede26.png

代码实现:

def manhatan_distance(x,y):
    return np.sum(np.abs(x-y))

(3) 切比雪夫距离(Chebyshev distance)

在数学中,切比雪夫距离(Chebyshev distance)或是L∞度量,是向量空间中的一种度量,二个点之间的距离定义是其各坐标数值差绝对值的最大值。以数学的观点来看,切比雪夫距离是由一致范数(uniform norm)(或称为上确界范数)所衍生的度量,也是超凸度量(injective metric space)的一种。


距离公式:

5265b6f72bb99a0f1f40a587dc73a66.png

若将国际象棋棋盘放在二维直角座标系中,格子的边长定义为1,座标的x xx轴及y yy轴和棋盘方格平行,原点恰落在某一格的中心点,则王从一个位置走到其他位置需要的步数恰为二个位置的切比雪夫距离,因此切比雪夫距离也称为棋盘距离。例如位置F6和位置E2的切比雪夫距离为4。任何一个不在棋盘边缘的位置,和周围八个位置的切比雪夫距离都是1。

代码实现:

def chebysev_distance(x,y):
    return np.max(np.abs(x-y))

(4) 闵可夫斯基距离(Minkowski distance)

闵氏空间指狭义相对论中由一个时间维和三个空间维组成的时空,为俄裔德国数学家闵可夫斯基(H.Minkowski,1864-1909)最先表述。他的平坦空间(即假设没有重力,曲率为零的空间)的概念以及表示为特殊距离量的几何学是与狭义相对论的要求相一致的。闵可夫斯基空间不同于牛顿力学的平坦空间。p pp取1或2时的闵氏距离是最为常用的,p = 2 p= 2p=2即为欧氏距离,而p = 1 p =1p=1时则为曼哈顿距离。


当p pp取无穷时的极限情况下,可以得到切比雪夫距离。


距离公式:

28673f225e80cd873e70a3e57333724.png

代码实现:

def minkowski(x, y, p):
    return np.sum(np.abs(x - y)**p)**(1 / p)

(5) 汉明距离(Hamming distance)

汉明距离是使用在数据传输差错控制编码里面的,汉明距离是一个概念,它表示两个(相同长度)字对应位不同的数量,我们以表示两个字,之间的汉明距离。对两个字符串进行异或运算,并统计结果为1的个数,那么这个数就是汉明距离。


距离公式:

80195ec9dbec4db23efe6b83619485c.png

def hamming(x,y):
    return np.sum(x!=y)/len(x)

(6) 余弦相似度(Cosine Similarity)

余弦相似性通过测量两个向量的夹角的余弦值来度量它们之间的相似性。0度角的余弦值是1,而其他任何角度的余弦值都不大于1;并且其最小值是-1。从而两个向量之间的角度的余弦值确定两个向量是否大致指向相同的方向。两个向量有相同的指向时,余弦相似度的值为1;两个向量夹角为90°时,余弦相似度的值为0;两个向量指向完全相反的方向时,余弦相似度的值为-1。这结果是与向量的长度无关的,仅仅与向量的指向方向相关。余弦相似度通常用于正空间,因此给出的值为0到1之间。


二维空间为例,上图的a aa和b bb是两个向量,我们要计算它们的夹角θ。余弦定理告诉我们,可以用下面的公式求得:


953f2b45a37bef4668b31dded22ae9b.png

953f2b45a37bef4668b31dded22ae9b.png

代码实现:

def square_rooted(x):
    return np.sqrt(np.sum(np.power(x,2)))
def cosine_similarity_distance(x,y):
    fenzi=np.sum(np.multiply(x,y))
    fenmu=square_rooted(x)*square_rooted(y)
    return fenzi/fenmu
import numpy as np
print(cosine_similarity_distance([3, 45, 7, 2], [2, 54, 13, 15]))
0.9722842517123499

KNN算法介绍

1.k kk近邻法是基本且简单的分类与回归方法。k kk近邻法的基本做法是:对给定的训练实例点和输入实例点,首先确定输入实例点的k kk个最近邻训练实例点,然后利用这k kk个训练实例点的类的多数来预测输入实例点的类。


2.k kk近邻模型对应于基于训练数据集对特征空间的一个划分。k kk近邻法中,当训练集、距离度量、k kk值及分类决策规则确定后,其结果唯一确定。


3.k kk近邻法三要素:距离度量、k kk值的选择和分类决策规则。常用的距离度量是欧氏距离。k kk值小时,k kk近邻模型更复杂;k kk值大时,k kk近邻模型更简单。k kk值的选择反映了对近似误差与估计误差之间的权衡,通常由交叉验证选择最优的k kk。


常用的分类决策规则是多数表决,对应于经验风险最小化。


4.k kk近邻法的实现需要考虑如何快速搜索k个最近邻点。kd树是一种便于对k维空间中的数据进行快速检索的数据结构。kd树是二叉树,表示对k kk维空间的一个划分,其每个结点对应于k kk维空间划分中的一个超矩形区域。利用kd树可以省去对大部分数据点的搜索, 从而减少搜索的计算量。


python实现,遍历所有数据点,找出n nn个距离最近的点的分类情况,少数服从多数


1 数据的准备

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from collections import Counter

导入鸢尾花数据集

iris = load_iris()
iris
{'data': array([[5.1, 3.5, 1.4, 0.2],
        [4.9, 3. , 1.4, 0.2],
        [4.7, 3.2, 1.3, 0.2],
        [4.6, 3.1, 1.5, 0.2],
        [5. , 3.6, 1.4, 0.2],
        [5.4, 3.9, 1.7, 0.4],
        [4.6, 3.4, 1.4, 0.3],
        [5. , 3.4, 1.5, 0.2],
        [4.4, 2.9, 1.4, 0.2],
        [4.9, 3.1, 1.5, 0.1],
        [5.4, 3.7, 1.5, 0.2],
        [4.8, 3.4, 1.6, 0.2],
        [4.8, 3. , 1.4, 0.1],
        [4.3, 3. , 1.1, 0.1],
        [5.8, 4. , 1.2, 0.2],
        [5.7, 4.4, 1.5, 0.4],
        [5.4, 3.9, 1.3, 0.4],
        [5.1, 3.5, 1.4, 0.3],
        [5.7, 3.8, 1.7, 0.3],
        [5.1, 3.8, 1.5, 0.3],
        [5.4, 3.4, 1.7, 0.2],
        [5.1, 3.7, 1.5, 0.4],
        [4.6, 3.6, 1. , 0.2],
        [5.1, 3.3, 1.7, 0.5],
        [4.8, 3.4, 1.9, 0.2],
        [5. , 3. , 1.6, 0.2],
        [5. , 3.4, 1.6, 0.4],
        [5.2, 3.5, 1.5, 0.2],
        [5.2, 3.4, 1.4, 0.2],
        [4.7, 3.2, 1.6, 0.2],
        [4.8, 3.1, 1.6, 0.2],
        [5.4, 3.4, 1.5, 0.4],
        [5.2, 4.1, 1.5, 0.1],
        [5.5, 4.2, 1.4, 0.2],
        [4.9, 3.1, 1.5, 0.2],
        [5. , 3.2, 1.2, 0.2],
        [5.5, 3.5, 1.3, 0.2],
        [4.9, 3.6, 1.4, 0.1],
        [4.4, 3. , 1.3, 0.2],
        [5.1, 3.4, 1.5, 0.2],
        [5. , 3.5, 1.3, 0.3],
        [4.5, 2.3, 1.3, 0.3],
        [4.4, 3.2, 1.3, 0.2],
        [5. , 3.5, 1.6, 0.6],
        [5.1, 3.8, 1.9, 0.4],
        [4.8, 3. , 1.4, 0.3],
        [5.1, 3.8, 1.6, 0.2],
        [4.6, 3.2, 1.4, 0.2],
        [5.3, 3.7, 1.5, 0.2],
        [5. , 3.3, 1.4, 0.2],
        [7. , 3.2, 4.7, 1.4],
        [6.4, 3.2, 4.5, 1.5],
        [6.9, 3.1, 4.9, 1.5],
        [5.5, 2.3, 4. , 1.3],
        [6.5, 2.8, 4.6, 1.5],
        [5.7, 2.8, 4.5, 1.3],
        [6.3, 3.3, 4.7, 1.6],
        [4.9, 2.4, 3.3, 1. ],
        [6.6, 2.9, 4.6, 1.3],
        [5.2, 2.7, 3.9, 1.4],
        [5. , 2. , 3.5, 1. ],
        [5.9, 3. , 4.2, 1.5],
        [6. , 2.2, 4. , 1. ],
        [6.1, 2.9, 4.7, 1.4],
        [5.6, 2.9, 3.6, 1.3],
        [6.7, 3.1, 4.4, 1.4],
        [5.6, 3. , 4.5, 1.5],
        [5.8, 2.7, 4.1, 1. ],
        [6.2, 2.2, 4.5, 1.5],
        [5.6, 2.5, 3.9, 1.1],
        [5.9, 3.2, 4.8, 1.8],
        [6.1, 2.8, 4. , 1.3],
        [6.3, 2.5, 4.9, 1.5],
        [6.1, 2.8, 4.7, 1.2],
        [6.4, 2.9, 4.3, 1.3],
        [6.6, 3. , 4.4, 1.4],
        [6.8, 2.8, 4.8, 1.4],
        [6.7, 3. , 5. , 1.7],
        [6. , 2.9, 4.5, 1.5],
        [5.7, 2.6, 3.5, 1. ],
        [5.5, 2.4, 3.8, 1.1],
        [5.5, 2.4, 3.7, 1. ],
        [5.8, 2.7, 3.9, 1.2],
        [6. , 2.7, 5.1, 1.6],
        [5.4, 3. , 4.5, 1.5],
        [6. , 3.4, 4.5, 1.6],
        [6.7, 3.1, 4.7, 1.5],
        [6.3, 2.3, 4.4, 1.3],
        [5.6, 3. , 4.1, 1.3],
        [5.5, 2.5, 4. , 1.3],
        [5.5, 2.6, 4.4, 1.2],
        [6.1, 3. , 4.6, 1.4],
        [5.8, 2.6, 4. , 1.2],
        [5. , 2.3, 3.3, 1. ],
        [5.6, 2.7, 4.2, 1.3],
        [5.7, 3. , 4.2, 1.2],
        [5.7, 2.9, 4.2, 1.3],
        [6.2, 2.9, 4.3, 1.3],
        [5.1, 2.5, 3. , 1.1],
        [5.7, 2.8, 4.1, 1.3],
        [6.3, 3.3, 6. , 2.5],
        [5.8, 2.7, 5.1, 1.9],
        [7.1, 3. , 5.9, 2.1],
        [6.3, 2.9, 5.6, 1.8],
        [6.5, 3. , 5.8, 2.2],
        [7.6, 3. , 6.6, 2.1],
        [4.9, 2.5, 4.5, 1.7],
        [7.3, 2.9, 6.3, 1.8],
        [6.7, 2.5, 5.8, 1.8],
        [7.2, 3.6, 6.1, 2.5],
        [6.5, 3.2, 5.1, 2. ],
        [6.4, 2.7, 5.3, 1.9],
        [6.8, 3. , 5.5, 2.1],
        [5.7, 2.5, 5. , 2. ],
        [5.8, 2.8, 5.1, 2.4],
        [6.4, 3.2, 5.3, 2.3],
        [6.5, 3. , 5.5, 1.8],
        [7.7, 3.8, 6.7, 2.2],
        [7.7, 2.6, 6.9, 2.3],
        [6. , 2.2, 5. , 1.5],
        [6.9, 3.2, 5.7, 2.3],
        [5.6, 2.8, 4.9, 2. ],
        [7.7, 2.8, 6.7, 2. ],
        [6.3, 2.7, 4.9, 1.8],
        [6.7, 3.3, 5.7, 2.1],
        [7.2, 3.2, 6. , 1.8],
        [6.2, 2.8, 4.8, 1.8],
        [6.1, 3. , 4.9, 1.8],
        [6.4, 2.8, 5.6, 2.1],
        [7.2, 3. , 5.8, 1.6],
        [7.4, 2.8, 6.1, 1.9],
        [7.9, 3.8, 6.4, 2. ],
        [6.4, 2.8, 5.6, 2.2],
        [6.3, 2.8, 5.1, 1.5],
        [6.1, 2.6, 5.6, 1.4],
        [7.7, 3. , 6.1, 2.3],
        [6.3, 3.4, 5.6, 2.4],
        [6.4, 3.1, 5.5, 1.8],
        [6. , 3. , 4.8, 1.8],
        [6.9, 3.1, 5.4, 2.1],
        [6.7, 3.1, 5.6, 2.4],
        [6.9, 3.1, 5.1, 2.3],
        [5.8, 2.7, 5.1, 1.9],
        [6.8, 3.2, 5.9, 2.3],
        [6.7, 3.3, 5.7, 2.5],
        [6.7, 3. , 5.2, 2.3],
        [6.3, 2.5, 5. , 1.9],
        [6.5, 3. , 5.2, 2. ],
        [6.2, 3.4, 5.4, 2.3],
        [5.9, 3. , 5.1, 1.8]]),
 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
 'frame': None,
 'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'),
 'DESCR': '.. _iris_dataset:\n\nIris plants dataset\n--------------------\n\n**Data Set Characteristics:**\n\n    :Number of Instances: 150 (50 in each of three classes)\n    :Number of Attributes: 4 numeric, predictive attributes and the class\n    :Attribute Information:\n        - sepal length in cm\n        - sepal width in cm\n        - petal length in cm\n        - petal width in cm\n        - class:\n                - Iris-Setosa\n                - Iris-Versicolour\n                - Iris-Virginica\n                \n    :Summary Statistics:\n\n    ============== ==== ==== ======= ===== ====================\n                    Min  Max   Mean    SD   Class Correlation\n    ============== ==== ==== ======= ===== ====================\n    sepal length:   4.3  7.9   5.84   0.83    0.7826\n    sepal width:    2.0  4.4   3.05   0.43   -0.4194\n    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)\n    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)\n    ============== ==== ==== ======= ===== ====================\n\n    :Missing Attribute Values: None\n    :Class Distribution: 33.3% for each of 3 classes.\n    :Creator: R.A. Fisher\n    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n    :Date: July, 1988\n\nThe famous Iris database, first used by Sir R.A. Fisher. The dataset is taken\nfrom Fisher\'s paper. Note that it\'s the same as in R, but not as in the UCI\nMachine Learning Repository, which has two wrong data points.\n\nThis is perhaps the best known database to be found in the\npattern recognition literature.  Fisher\'s paper is a classic in the field and\nis referenced frequently to this day.  (See Duda & Hart, for example.)  The\ndata set contains 3 classes of 50 instances each, where each class refers to a\ntype of iris plant.  One class is linearly separable from the other 2; the\nlatter are NOT linearly separable from each other.\n\n.. topic:: References\n\n   - Fisher, R.A. "The use of multiple measurements in taxonomic problems"\n     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to\n     Mathematical Statistics" (John Wiley, NY, 1950).\n   - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.\n     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.\n   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System\n     Structure and Classification Rule for Recognition in Partially Exposed\n     Environments".  IEEE Transactions on Pattern Analysis and Machine\n     Intelligence, Vol. PAMI-2, No. 1, 67-71.\n   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions\n     on Information Theory, May 1972, 431-433.\n   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II\n     conceptual clustering system finds 3 classes in the data.\n   - Many, many more ...',
 'feature_names': ['sepal length (cm)',
  'sepal width (cm)',
  'petal length (cm)',
  'petal width (cm)'],
 'filename': 'iris.csv',
 'data_module': 'sklearn.datasets.data'}
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df["target"]=iris.target
df.columns=iris.feature_names+["target"]
df
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) target
0 5.1 3.5 1.4 0.2 0
1 4.9 3.0 1.4 0.2 0
2 4.7 3.2 1.3 0.2 0
3 4.6 3.1 1.5 0.2 0
4 5.0 3.6 1.4 0.2 0
... ... ... ... ... ...
145 6.7 3.0 5.2 2.3 2
146 6.3 2.5 5.0 1.9 2
147 6.5 3.0 5.2 2.0 2
148 6.2 3.4 5.4 2.3 2
149 5.9 3.0 5.1 1.8 2

150 rows × 5 columns

df.head()

sepal length (cm)
sepal width (cm) petal length (cm) petal width (cm) target
0 5.1 3.5 1.4 0.2 0
1 4.9 3.0 1.4 0.2 0
2 4.7 3.2 1.3 0.2 0
3 4.6 3.1 1.5 0.2 0
4 5.0 3.6 1.4 0.2 0

选择长和宽的数据进行可视化

#选取前100行数据进行可视化
plt.figure(figsize=(12, 8))
plt.scatter(df[:50]["sepal length (cm)"], df[:50]["sepal width (cm)"], label='0')
plt.scatter(df[50:100]["sepal length (cm)"], df[50:100]["sepal width (cm)"], label='1')
plt.xlabel('sepal length', fontsize=18)
plt.ylabel('sepal width', fontsize=18)
plt.legend()
plt.show()

2 划分训练数据和测试数据

from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test=train_test_split(df.iloc[:100,:2].values,df.iloc[:100,-1].values)
X_train.shape,X_test.shape,y_train.shape,y_test.shape
((75, 2), (25, 2), (75,), (25,))
X_train,y_train
(array([[5. , 3.3],
        [4.6, 3.4],
        [5.2, 4.1],
        [5.7, 2.8],
        [5.1, 3.4],
        [4.8, 3. ],
        [5.9, 3.2],
        [5.7, 3.8],
        [4.8, 3.4],
        [5.3, 3.7],
        [5.1, 3.8],
        [5.5, 2.4],
        [6. , 2.2],
        [5.5, 4.2],
        [5.5, 2.6],
        [5.4, 3.4],
        [4.4, 2.9],
        [6. , 2.9],
        [5.8, 2.7],
        [4.4, 3.2],
        [5.6, 2.9],
        [5.8, 2.7],
        [6.7, 3.1],
        [6. , 2.7],
        [5.7, 2.9],
        [4.6, 3.2],
        [4.9, 3.1],
        [7. , 3.2],
        [4.7, 3.2],
        [5.1, 2.5],
        [6.3, 2.3],
        [4.6, 3.1],
        [6.4, 3.2],
        [6.6, 3. ],
        [4.6, 3.6],
        [5.5, 2.4],
        [5.6, 3. ],
        [5.1, 3.7],
        [6.1, 2.8],
        [5.6, 2.7],
        [4.8, 3.1],
        [4.8, 3. ],
        [5. , 3.5],
        [6.2, 2.2],
        [6. , 3.4],
        [5.1, 3.3],
        [5.4, 3.9],
        [5.7, 2.6],
        [6.7, 3.1],
        [4.5, 2.3],
        [4.8, 3.4],
        [4.9, 2.4],
        [5.8, 4. ],
        [5. , 3. ],
        [6.6, 2.9],
        [6.1, 2.9],
        [5. , 3.5],
        [6.8, 2.8],
        [5. , 2.3],
        [5.4, 3. ],
        [4.3, 3. ],
        [4.9, 3.1],
        [4.9, 3. ],
        [5.1, 3.8],
        [5.1, 3.5],
        [5.5, 2.5],
        [5. , 3.6],
        [5. , 3.4],
        [5.4, 3.9],
        [5.1, 3.8],
        [5.1, 3.5],
        [5.2, 3.5],
        [5.8, 2.6],
        [6.4, 2.9],
        [6.1, 2.8]]),
 array([0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1,
        1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1,
        1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1,
        0, 0, 0, 0, 0, 0, 1, 1, 1]))

3 通过K个近邻预测的标签的距离来预测当前样本的标签

#定义邻居数量
from collections import Counter
K=3
KNN_x=[]
for i in range(X_train.shape[0]):
    if len(KNN_x)<K:
        KNN_x.append((euclidean(X_test[0],X_train[i]),y_train[i]))
KNN_x
[(0.6324555320336757, 0), (0.9219544457292889, 0), (1.3999999999999995, 0)]
count=Counter([item[1] for item in KNN_x])
count
Counter({0: 3})
count.items()
dict_items([(0, 3)])
sorted(count.items(),key=lambda x:x[1])[-1][0]
0
#返回任意一个样本x的标签
def calcu_distance_return(x,X_train,y_train):
    KNN_x=[]
    #遍历训练集中的每个样本
    for i in range(X_train.shape[0]):
        if len(KNN_x)<K:
            KNN_x.append((euclidean(x,X_train[i]),y_train[i]))
        else:
            KNN_x.sort()
            for j in range(K): 
                if (euclidean(x,X_train[i]))< KNN_x[j][0]:
                    KNN_x[j]=(euclidean(x,X_train[i]),y_train[i])
                    break
    knn_label=[item[1] for item in KNN_x]           
    counter_knn=Counter(knn_label) 
    return sorted(counter_knn.items(),key=lambda item:item[1])[-1][0]                  
#对整个测试集进行预测
def predict(X_test):
    y_pred=np.zeros(X_test.shape[0])
    for i in range(X_test.shape[0]):
        y_hat_i=calcu_distance_return(X_test[i],X_train,y_train) 
        y_pred[i]=y_hat_i
    return y_pred

4 计算准确率

#输出预测结果
y_pred= predict(X_test).astype("int32")
y_pred
array([1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1,
       1, 1, 0])
y_test=y_test.astype("int32")
y_test
array([1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1,
       1, 1, 0])
#计算准确率
np.sum(y_pred==y_test)/X_test.shape[0]
1.0


目录
相关文章
|
8天前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
29 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
29天前
|
机器学习/深度学习 算法 Java
机器学习、基础算法、python常见面试题必知必答系列大全:(面试问题持续更新)
机器学习、基础算法、python常见面试题必知必答系列大全:(面试问题持续更新)
|
1月前
|
机器学习/深度学习 人工智能 算法
【玉米病害识别】Python+卷积神经网络算法+人工智能+深度学习+计算机课设项目+TensorFlow+模型训练
玉米病害识别系统,本系统使用Python作为主要开发语言,通过收集了8种常见的玉米叶部病害图片数据集('矮花叶病', '健康', '灰斑病一般', '灰斑病严重', '锈病一般', '锈病严重', '叶斑病一般', '叶斑病严重'),然后基于TensorFlow搭建卷积神经网络算法模型,通过对数据集进行多轮迭代训练,最后得到一个识别精度较高的模型文件。再使用Django搭建Web网页操作平台,实现用户上传一张玉米病害图片识别其名称。
56 0
【玉米病害识别】Python+卷积神经网络算法+人工智能+深度学习+计算机课设项目+TensorFlow+模型训练
|
1月前
|
机器学习/深度学习 算法 Python
探索机器学习中的决策树算法:从理论到实践
【10月更文挑战第5天】本文旨在通过浅显易懂的语言,带领读者了解并实现一个基础的决策树模型。我们将从决策树的基本概念出发,逐步深入其构建过程,包括特征选择、树的生成与剪枝等关键技术点,并以一个简单的例子演示如何用Python代码实现一个决策树分类器。文章不仅注重理论阐述,更侧重于实际操作,以期帮助初学者快速入门并在真实数据上应用这一算法。
|
18天前
|
机器学习/深度学习 人工智能 算法
探索机器学习中的决策树算法
【10月更文挑战第29天】本文将深入浅出地介绍决策树算法,一种在机器学习中广泛使用的分类和回归方法。我们将从基础概念出发,逐步深入到算法的实际应用,最后通过一个代码示例来直观展示如何利用决策树解决实际问题。无论你是机器学习的初学者还是希望深化理解的开发者,这篇文章都将为你提供有价值的见解和指导。
|
1月前
|
机器学习/深度学习 算法 数据处理
EM算法对人脸数据降维(机器学习作业06)
本文介绍了使用EM算法对人脸数据进行降维的机器学习作业。首先通过加载ORL人脸数据库,然后分别应用SVD_PCA、MLE_PCA及EM_PCA三种方法实现数据降维,并输出降维后的数据形状。此作业展示了不同PCA变种在人脸数据处理中的应用效果。
34 0
|
1月前
|
机器学习/深度学习 算法 搜索推荐
从理论到实践,Python算法复杂度分析一站式教程,助你轻松驾驭大数据挑战!
【10月更文挑战第4天】在大数据时代,算法效率至关重要。本文从理论入手,介绍时间复杂度和空间复杂度两个核心概念,并通过冒泡排序和快速排序的Python实现详细分析其复杂度。冒泡排序的时间复杂度为O(n^2),空间复杂度为O(1);快速排序平均时间复杂度为O(n log n),空间复杂度为O(log n)。文章还介绍了算法选择、分而治之及空间换时间等优化策略,帮助你在大数据挑战中游刃有余。
60 4
|
4月前
|
机器学习/深度学习 算法 搜索推荐
从理论到实践,Python算法复杂度分析一站式教程,助你轻松驾驭大数据挑战!
【7月更文挑战第22天】在大数据领域,Python算法效率至关重要。本文深入解析时间与空间复杂度,用大O表示法衡量执行时间和存储需求。通过冒泡排序(O(n^2)时间,O(1)空间)与快速排序(平均O(n log n)时间,O(log n)空间)实例,展示Python代码实现与复杂度分析。策略包括算法适配、分治法应用及空间换取时间优化。掌握这些,可提升大数据处理能力,持续学习实践是关键。
125 1
|
5月前
|
存储 机器学习/深度学习 算法
Python算法基础教程
Python算法基础教程
31 0
|
数据采集 SQL 算法
C++、Python、数据结构与算法、计算机基础、数据库教程汇总!
C++、Python、数据结构与算法、计算机基础、数据库教程汇总!
220 0
C++、Python、数据结构与算法、计算机基础、数据库教程汇总!
下一篇
无影云桌面