机器学习笔记之KNN分类

简介:

KNN分类器作为有监督学习中较为通俗易懂的分类算法,在各类分类任务中经常使用。

KNN模型的核心思想很简单,即近朱者赤、近墨者黑,它通过将每一个测试集样本点与训练集中每一个样本之间测算欧氏距离,然后取欧氏距离最近的K个点(k是可以人为划定的近邻取舍个数,K的确定会影响算法结果),并统计这K个训练集样本点所属类别频数,将其中频数最高的所属类别化为该测试样本点的预测类别。

1972d2e799cbc1ba802b6eda6ce50e6d306b73eb

这样意味着测试集中的每一个点都需要与训练集每一个样本点之间计算一次欧氏距离,算法复杂度较高。

其伪代码如下:

d47e62d2b349aca45e42305ed6714efbe5ed61d9 计算已知类别数据集中的点与当前点之间的距离;
d47e62d2b349aca45e42305ed6714efbe5ed61d9 按照距离递增次序排序;
d47e62d2b349aca45e42305ed6714efbe5ed61d9 选择与当前距离最小的k个点;
d47e62d2b349aca45e42305ed6714efbe5ed61d9 确定前k个点所在类别的出现概率
d47e62d2b349aca45e42305ed6714efbe5ed61d9 返回前k个点出现频率最高的类别作为当前点的预测分类。

其优点主要体现在简单易懂,无需训练;
但其数据结果对训练样本中的类别分布状况很敏感,类别分布不平衡会影响分类结果;
对设定的k值(选取的近邻个数)也会影响最终划分的类别;
随着训练集与测试集的增加,算法复杂度较高,内存占用高。

KNN算法中的距离度量既可以采用欧式距离,也可以采用余弦距离(文本分类任务),欧氏距离会受到特征量级大小的影响,因而需要在训练前进行数据标准化。

本次练习使用莺尾花数据集(数据比较规范、量级小适合单机训练)。

R Code:


## !/user/bin/env RStudio 1.1.423
## -*- coding: utf-8 -*-
## KNN Model

library("dplyr")
library('caret')
rm(list = ls())
gc()

#数据转换(数据导入、数据标准化、测试集与训练集分割、样本与标签分配)

Data_Input <- function(file_path = "D:/R/File/iris.csv",p = .75){
data = read.csv(file_path,stringsAsFactors = FALSE,check.names = FALSE)
names(data) <- c('sepal_length','sepal_width','petal_length','petal_width','class')
data[,-ncol(data)] <- scale(data[,-ncol(data)])
data['class_c'] = as.numeric(as.factor(data$class))
x = data[,1:(ncol(data)-2)];y = data$class_c
samples = sample(nrow(data),p*nrow(data))
train_data = x[samples,1:(ncol(data)-2)];train_target = y[samples]
test_data = data[-samples,1:(ncol(data)-2)];test_target = y[-samples]
return(
list(
data = data,
train_data = train_data,
test_data = test_data,
train_target = train_target,
test_target = test_target ) )
}
# 分类器构建(距离计算、排序、统计类别频数、输出最高频类别作为预测类):

kNN_Classify <-function(test_data,test_target,train_data,train_target,k){
# step 1: 计算距离
centr_matrix = unlist(rep(test_data,time = nrow(train_data)),use.names = FALSE) %>%
matrix(byrow = TRUE,ncol = 4)
diff = as.matrix(train_data) - centr_matrix
squaredDist = apply(diff^2,1,sum)
distance = as.numeric(squaredDist ^ 0.5)
# step 2: 对距离排序
sortedDistIndices = rank(distance)
classCount = c()
for (i in 1:k){
# step 3: 选择k个最近邻居
target_sort = train_target[sortedDistIndices == i]
classCount = c(classCount,target_sort)
}
# step 4: 分类统计并返回频数最高的类
Max_count = plyr::count(classCount) %>% arrange(-freq) %>%.[1,1]
return (Max_count)
}
data_source <- Data_Input()
train_data <- data_source$train_data
test_data <- data_source$test_data
train_target <- data_source$train_target
test_target <- data_source$test_target
# 测试单样本分类

kNN_Classify(
test_data = test_data[1,] ,
test_target = test_target,
train_data = train_data,
train_target = train_target,
k = 5
)

# 构建全样本分类任务(全样本扫描)、输出混洗矩阵与预测类别结果

datingClassTest <- function(test_data,train_data,train_target,test_target,k = 5){
m = nrow(test_data)
w = ncol(test_data)
errorCount = 0.0
test_predict = c()
for (i in 1:m){
classifierResult = kNN_Classify(
test_data = test_data[i,],
train_target = train_target,
train_data = train_data, k = k
)
if (classifierResult != test_target[i]){
errorCount = errorCount + 1.0
}
test_predict = c(test_predict,classifierResult)
}
test_data[['test_predict']] = test_predict
test_data[['class']] = test_target
target_names = c('setosa', 'versicolor', 'virginica')
print(confusionMatrix(factor(test_predict,labels = target_names),factor(test_target,labels = target_names)))
confusion_matrix = table(test_predict,test_target)
cat(sprintf("in datingClassTest,the total error rate is: %f",errorCount/m),sep = '\n')
dimnames(confusion_matrix) <- list(target_names,target_names)
cat(sprintf('in datingClassTest,errorCount:%d',errorCount),sep = '\n')
return (list(test_data = test_data,confusion_matrix = confusion_matrix))
}
#执行分类任务
result <- datingClassTest(
test_data = test_data,
train_target = train_target,
train_data = train_data,
result$test_data
test_target = test_target ) 预测结果收集与混洗矩阵输出:
result$confusion_matrix
11702f13bd386401dc1cf617d512e4ab09a208f3

从结果来看,整体样本划分准确率为92.1%,一共错判了三个点,错误率为7.89%,考虑到数据集随机划分导致的样本类别平衡问题,每次分类结果都可能不一致(可通过设置随机种子来复现抽样结果),这里的K值确定需要根据实际交叉验证情况进行择优取舍。

Python:


#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import time
import csv
from numpy import tile
from sklearn import preprocessing
from collections import Counter
from collections import OrderedDict
import pandas as pd
from sklearn import neighbors
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from time import time

'''KNN分类器'''

## 数据集读入、训练集与测试集样本划分及数据标准化:

def Data_Input():
data = pd.read_csv("D:/Python/File/iris.csv")
data.columns = ['sepal_length','sepal_width','petal_length','petal_width','class']
data.iloc[:,0:-1] = preprocessing.scale(data.iloc[:,0:-1])
data['class_c'] = pd.factorize(data['class'])[0]
x,y = data.iloc[:,0:-2],data.iloc[:,-1]
print(data.shape,'\n',data.head())
train_data,test_data,train_target,test_target = train_test_split(x,y,test_size = 0.25)
return train_data,test_data,train_target,test_target
# KNN分类器函数
def kNN_Classify(test_data,train_data,train_target,k):
#diff = train_data - np.tile(test_data,(len(train_data),1)) #test_data = test_data.values[0].reshape(1,4) # 数据标准化 # step 1: 计算距离
diff = train_data - np.repeat(test_data,repeats = len(train_data) ,axis = 0)
squaredDist = np.sum(diff ** 2, axis = 1)
distance = squaredDist ** 0.5 # step 2: 对距离排序
sortedDistIndices = np.argsort(distance).values
classCount = []
for i in range(k):
# step 3: 选择k个最近邻
target_sort = train_target.values[sortedDistIndices[i]]
classCount.append(target_sort)
# step 4: 计算k个最近邻中各类别出现的次数
counter = Counter(classCount)
# step 5: 返回出现次数最多的类别标签
Max_count = counter.most_common(1)[0][0]
return Max_count
#单样本测试:
kNN_Classify(test_data.values[0].reshape(1,4),train_data,train_target,k = 5)

#构建全样本扫描的分类器并输出分类结果与混洗矩阵:

def datingClassTest(test_data,train_data,train_target,test_target,k = 5):
m = test_data.shape[0]
w = test_data.shape[1]
errorCount = 0.0
test_predict = []
for i in range(m):
classifierResult = kNN_Classify(
test_data = test_data.values[i].reshape(1,w),
train_data = train_data,
train_target = train_target,
k = k
)
if (classifierResult != test_target.values[i]):
errorCount += 1.0
test_predict.append(classifierResult)
test_data['test_predict'] = test_predict
test_data['class'] = test_target
confusion_matrix = pd.crosstab(train_target,test_predict)
print ("in datingClassTest,the total error rate is: %f" % (errorCount/float(m)))
print ('in datingClassTest,errorCount:',errorCount)
target_names = ['setosa', 'versicolor', 'virginica']
print(classification_report(test_target,test_predict, target_names=target_names))
return test_data,confusion_matrix
#执行分类任务并输出分类结果到本地:

if __name__ == "__main__":
#计时开始:
t0 = time.time()
train_data,test_data,train_target,test_target = Data_Input()
test_reslut,confusion_matrix = datingClassTest(test_data,train_data,train_target,test_target,k = 5)
name = "KNN" + str(int(time.time())) + ".csv" print ("Generating results file:", name)
with open("D:/Python/File/" + name, "w",newline='') as csvfile:
open_file_object = csv.writer(csvfile)
open_file_object.writerow(['sepal_length','sepal_width','petal_length','petal_width','test_predict','class'])
open_file_object.writerows(test_reslut.values)
t1 = time.time() total = t1 - t0
print("消耗时间:{}".format(total))
9f0d75d85f343bfa7c7e045bdcee94ff7b3aa79a

这只是第一次尝试手写KNN,还没有做很好地代码封装和模型调优,作为代码实战的一个小开端,之后会更加注重特征选择和模型优化方面的学习~


原文发布时间为:2018-06-22

本文作者:杜雨

本文来自云栖社区合作伙伴“数据小魔方”,了解相关信息可以关注“数据小魔方”。

相关文章
|
22天前
|
机器学习/深度学习 人工智能 自然语言处理
机器学习之线性回归与逻辑回归【完整房价预测和鸢尾花分类代码解释】
机器学习之线性回归与逻辑回归【完整房价预测和鸢尾花分类代码解释】
|
1月前
|
机器学习/深度学习 数据可视化 算法
机器学习中的分类问题:如何选择和理解性能衡量标准
机器学习中的分类问题:如何选择和理解性能衡量标准
机器学习中的分类问题:如何选择和理解性能衡量标准
|
15天前
|
机器学习/深度学习 自然语言处理 算法
|
8天前
|
机器学习/深度学习 存储 算法
PYTHON集成机器学习:用ADABOOST、决策树、逻辑回归集成模型分类和回归和网格搜索超参数优化
PYTHON集成机器学习:用ADABOOST、决策树、逻辑回归集成模型分类和回归和网格搜索超参数优化
30 7
|
10天前
|
索引 机器学习/深度学习 Python
fast.ai 机器学习笔记(二)(3)
fast.ai 机器学习笔记(二)
24 0
fast.ai 机器学习笔记(二)(3)
|
10天前
|
机器学习/深度学习 算法框架/工具 PyTorch
fast.ai 机器学习笔记(三)(2)
fast.ai 机器学习笔记(三)
37 0
fast.ai 机器学习笔记(三)(2)
|
机器学习/深度学习 算法 计算机视觉
fast.ai 机器学习笔记(四)(4)
fast.ai 机器学习笔记(四)
18 0
fast.ai 机器学习笔记(四)(4)
|
10天前
|
机器学习/深度学习 索引 Python
fast.ai 机器学习笔记(四)(2)
fast.ai 机器学习笔记(四)
98 0
fast.ai 机器学习笔记(四)(2)
|
10天前
|
机器学习/深度学习 数据挖掘 Python
fast.ai 机器学习笔记(一)(4)
fast.ai 机器学习笔记(一)
75 1
fast.ai 机器学习笔记(一)(4)
|
10天前
|
机器学习/深度学习 Python 索引
fast.ai 机器学习笔记(一)(1)
fast.ai 机器学习笔记(一)
35 0
fast.ai 机器学习笔记(一)(1)