KNN算法学习笔记

简介: KNN算法学习笔记

一、何为KNN?


KNN(K- Nearest Neighbor)法即K最邻近法,最初由 Cover和Hart于1968年提出,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一。该方法的思路非常简单直观:如果一个样本在特征空间中的K个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别   。


该方法的不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最邻近点。目前常用的解决方法是事先对已知样本点进行剪辑,事先去除对分类作用不大的样本。另外还有一种 Reverse KNN法,它能降低KNN算法的计算复杂度,提高分类的效率   。


KNN算法比较适用于样本容量比较大的类域的自动分类,而那些样本容量较小的类域采用这种算法比较容易产生误分

f232558ab1734f82ae22b8df8d81e6f9_c00e46f2c3154618814d1a616bf8a7b9.png

二、核心思想


KNN算法的核心思想是,如果一个样本在特征空间中的K个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。KNN方法在类别决策时,只与极少量的相邻样本有关。由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合  


优缺点:


优点:KNN方法思路简单,易于理解,易于实现,无需估计参数

缺点:1.当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数;2.计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点


三、算法流程(实现步骤)


  1. 准备数据,对数据进行处理
  2. 计算测试样本点(也就是待分类点)得到其他每个样本点的距离

1.欧几里得距离

add8e312c71f7847601889dafcf42f03_39427c50d8ee4292ab4a140bb16a4da5.png

2.马氏距离

3d8724af46778ed8d276969e186d1ff4_7b9c7b8d496a440b9eecd058c25d8d96.png

3.升降排序(对每个距离进行排序)

4.取前k个(选择出距离最小的k个点)

  1. k取值太小:受个例影响比较严重,波动较大
  2. k取值太大:导致分类模糊

5.加权平均(对k个点所属的类别进行比较,根据少数服从多数的原则,将测试样本点归入在k个点中占比最高的哪一类)


四、实战案例


案例:将mnist数据集和fashion mnist数据集包括训练集和验证集导入到工程文件中,接着计算验证集和训练集的距离,并从小到达排序得到距离最近的k个邻居,并通过投票得到所属类别最高的类别,并判断该验证集的图片属于该类别,接着讲该类别的标签和验证集的标签进行比对,如果相符合则是正确的,如果不相符合,则是属于出错,最后输出计算出的错误率和准确率。

from numpy import *
import operator
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from os import listdir
from mpl_toolkits.mplot3d import Axes3D
import struct
#读取图片
def read_image(file_name):
    #先用二进制方式把文件都读进来
    file_handle=open(file_name,"rb")  #以二进制打开文档
    file_content=file_handle.read()   #读取到缓冲区中
    offset=0
    head = struct.unpack_from('>IIII', file_content, offset)  # 取前4个整数,返回一个元组
    offset += struct.calcsize('>IIII')
    imgNum = head[1]  #图片数
    rows = head[2]   #宽度
    cols = head[3]  #高度
    # print(imgNum)
    # print(rows)
    # print(cols)
    #测试读取一个图片是否读取成功
    #im = struct.unpack_from('>784B', file_content, offset)
    #offset += struct.calcsize('>784B')
    images=np.empty((imgNum , 784))#empty,是它所常见的数组内的所有元素均为空,没有实际意义,它是创建数组最快的方法
    image_size=rows*cols#单个图片的大小
    fmt='>' + str(image_size) + 'B'#单个图片的format
    for i in range(imgNum):
        images[i] = np.array(struct.unpack_from(fmt, file_content, offset))
        # images[i] = np.array(struct.unpack_from(fmt, file_content, offset)).reshape((rows, cols))
        offset += struct.calcsize(fmt)
    return images
    '''bits = imgNum * rows * cols  # data一共有60000*28*28个像素值
    bitsString = '>' + str(bits) + 'B'  # fmt格式:'>47040000B'
    imgs = struct.unpack_from(bitsString, file_content, offset)  # 取data数据,返回一个元组
    imgs_array=np.array(imgs).reshape((imgNum,rows*cols))     #最后将读取的数据reshape成 【图片数,图片像素】二维数组
    return imgs_array'''
#读取标签
def read_label(file_name):
    file_handle = open(file_name, "rb")  # 以二进制打开文档
    file_content = file_handle.read()  # 读取到缓冲区中
    head = struct.unpack_from('>II', file_content, 0)  # 取前2个整数,返回一个元组
    offset = struct.calcsize('>II')
    labelNum = head[1]  # label数
    # print(labelNum)
    bitsString = '>' + str(labelNum) + 'B'  # fmt格式:'>47040000B'
    label = struct.unpack_from(bitsString, file_content, offset)  # 取data数据,返回一个元组
    return np.array(label)
#KNN算法
def KNN(test_data, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]#dataSet.shape[0]表示的是读取矩阵第一维度的长度,代表行数
    # distance1 = tile(test_data, (dataSetSize,1)) - dataSet#欧氏距离计算开始
    # print("dataSetSize:")
    # print(dataSetSize)
    distance1 = tile(test_data, (dataSetSize)).reshape((60000,784))-dataSet#tile函数在行上重复dataSetSizec次,在列上重复1次
    # print("distance1.shape")
    # print(distance1.shape)
    distance2 = distance1**2 #每个元素平方
    distance3 = distance2.sum(axis=1)#矩阵每行相加
    distances4 = distance3**0.5#欧氏距离计算结束
    # print(distances4[53843])
    # print(distances4[38620])
    # print(distances4[16186])
    sortedDistIndicies = distances4.argsort() #返回从小到大排序的索引
    classCount=np.zeros((10), np.int32)#10是代表10个类别
    for i in range(k): #统计前k个数据类的数量
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] += 1
    max = 0
    id = 0
    print(classCount.shape[0])
    # print(classCount.shape[1])
    for i in range(classCount.shape[0]):
        if classCount[i] >= max:
            max = classCount[i]
            id = i
    print(id)
    # sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)#从大到小按类别数目排序
    return id
def test_KNN():
    # 文件获取
    #mnist数据集
    # train_image = "F:\mnist\\train-images-idx3-ubyte"
    # test_image = "F:\mnist\\t10k-images-idx3-ubyte"
    # train_label = "F:\mnist\\train-labels-idx1-ubyte"
    # test_label = "F:\mnist\\t10k-labels-idx1-ubyte"
    #fashion mnist数据集
    train_image = "train-images-idx3-ubyte"
    test_image = "t10k-images-idx3-ubyte"
    train_label = "train-labels-idx1-ubyte"
    test_label = "t10k-labels-idx1-ubyte"
    # 读取数据
    train_x = read_image(train_image)  # train_dataSet
    test_x = read_image(test_image)  # test_dataSet
    train_y = read_label(train_label)  # train_label
    test_y = read_label(test_label)  # test_label
    # print(train_x.shape)
    # print(test_x.shape)
    # print(train_y.shape)
    # print(test_y.shape)
    # plt.imshow(train_x[0])
    # plt.show()
    testRatio = 1  # 取数据集的前0.1为测试数据,这个参数比重可以改变
    train_row = train_x.shape[0]  # 数据集的行数,即数据集的总的样本数
    test_row=test_x.shape[0]
    testNum = int(test_row * testRatio)
    errorCount = 0  # 判断错误的个数
    for i in range(testNum):
        result = KNN(test_x[i], train_x, train_y, 30)
        # print('返回的结果是: %s, 真实结果是: %s' % (result, train_y[i]))
        print(result, test_y[i])
        if result != test_y[i]:
            errorCount += 1.0# 如果mnist验证集的标签和本身标签不一样,则出错
    error_rate = errorCount / float(testNum)  # 计算出错率
    acc = 1.0 - error_rate
    print(errorCount)
    print("\nthe total number of errors is: %d" % errorCount)
    print("\nthe total error rate is: %f" % (error_rate))
    print("\nthe total accuracy rate is: %f" % (acc))
if __name__ == "__main__":
    test_KNN()#test()函数中调用了读取数据集的函数,并调用分类函数对数据集进行分类,最后对分类情况进行计算

结果分析:


输入:mnist数据集或者fashion mnist数据集


输出:出错率和准确率


Mnist数据集:


取k=30,验证集是50个的时候,准确率是1;


取k=30,验证集是500个的时候,准确率是0.98;


取k=30,验证集是10000个的时候,准确率是0.84。


Fashion Mnist数据集


K=30,验证集是10000的时候,一共的出错个数是1666,准确率是0.8334。


目录
相关文章
|
3月前
|
机器学习/深度学习 算法
机器学习入门(三):K近邻算法原理 | KNN算法原理
机器学习入门(三):K近邻算法原理 | KNN算法原理
|
3月前
|
机器学习/深度学习 算法 API
机器学习入门(五):KNN概述 | K 近邻算法 API,K值选择问题
机器学习入门(五):KNN概述 | K 近邻算法 API,K值选择问题
|
4月前
|
算法 Python
KNN
【9月更文挑战第11天】
63 13
|
4月前
|
算法 大数据
K-最近邻(KNN)
K-最近邻(KNN)
|
4月前
|
机器学习/深度学习 算法 数据挖掘
R语言中的支持向量机(SVM)与K最近邻(KNN)算法实现与应用
【9月更文挑战第2天】无论是支持向量机还是K最近邻算法,都是机器学习中非常重要的分类算法。它们在R语言中的实现相对简单,但各有其优缺点和适用场景。在实际应用中,应根据数据的特性、任务的需求以及计算资源的限制来选择合适的算法。通过不断地实践和探索,我们可以更好地掌握这些算法并应用到实际的数据分析和机器学习任务中。
|
6月前
knn增强数据训练
【7月更文挑战第27天】
47 10
|
6月前
|
机器人 计算机视觉 Python
K-最近邻(KNN)分类器
【7月更文挑战第26天】
52 8
|
6月前
创建KNN类
【7月更文挑战第22天】创建KNN类。
39 8
|
6月前
knn增强数据训练
【7月更文挑战第28天】
51 2
|
6月前
|
机器学习/深度学习 数据采集 算法
Python实现PCA降维和KNN人脸识别模型(PCA和KNeighborsClassifier算法)项目实战
Python实现PCA降维和KNN人脸识别模型(PCA和KNeighborsClassifier算法)项目实战

热门文章

最新文章