kNN算法将找出k个距离最近的邻居作为目标的同一类别。
1.图解kNN算法
使用OpenCV的ml模块中的kNN算法的基本步骤如下。
(1)调用cv2.ml.KNearest_create()函数创建kNN分类器。
(2)将训练数据和标志作为输入,调用kNN分类器的train()方法训练模型。
(3)将待分类数据作为输入,调用kNN分类器的findNearest()方法找出k个最近邻居,返回分类结果的相关信息。
下面的代码在图像中随机选择20个点,为每个点随机分配标志(0或1);图像中用矩形表示标志0,用三角形表示标志1;再随机新增一个点,用kNN算法找出其邻居,并确定其标志(即完成分类)。
图解kNN算法
import cv2
import numpy as np
import matplotlib.pyplot as plt
points = np.random.randint(0,100,(20,2)) #随机选择20个点
labels = np.random.randint(0,2,(20,1)) #为随机点随机分配标志
label0s = points[labels.ravel()==0] #分出标志为0的点
plt.scatter(label0s[:,0],label0s[:,1],80,'b','s') #将标志为0的点绘制为蓝色矩形
label1s = points[labels.ravel()==1] #分出标志为1的点
plt.scatter(label1s[:,0],label1s[:,1],80,'r','^') #将标志为1的点绘制为红色三角形
newpoint = np.random.randint(0,100,(1,2)) #随机选择一个点,下面确定其分类
plt.scatter(newpoint[:,0],newpoint[:,1],80,'g','o') #将待分类新点绘制为绿色圆点
plt.show()
进一步使用kNN算法确认待分类新点的类别、3个最近邻居和距离
knn = cv2.ml.KNearest_create() #创建kNN分类器
knn.train(points.astype(np.float32), cv2.ml.ROW_SAMPLE,
labels.astype(np.float32)) #训练模型
ret,results,neighbours,dist = knn.findNearest(
newpoint.astype(np.float32), 3) #找出3个最近邻居
print( "新点标志: %s" % results)
print( "邻居: %s" % neighbours)
print( "距离:%s" % dist)
用kNN算法实现手写数字识别
OpenCV源代码中的“samples\data”文件夹下的digits.png文件是一个手写数字图像,如图10-2所示。
digits.png的大小为2000×1000,其中每个数字的大小为20×20,每个数字的样本有500个(5行、100列),共有5000个数字样本。可使用这些数字图像来训练kNN模型和执行测试。
用kNN算法实现手写识别
import cv2
import numpy as np
import matplotlib.pyplot as plt
gray = cv2.imread('digits.png',0) #读入手写数字的灰度图像
digits = [np.hsplit(r,100) for r in np.vsplit(gray,50)] #分解数字:50行、100列
np_digits = np.array(digits) #转换为NumPy数组
准备训练数据,转换为二维数组,每个图像400个像素
train_data = np_digits.reshape(-1,400).astype(np.float32)
train_labels = np.repeat(np.arange(10),500)[:,np.newaxis] #定义标志
knn = cv2.ml.KNearest_create() #创建kNN分类器
knn.train(train_data, cv2.ml.ROW_SAMPLE, train_labels) #训练模型
用绘图工具创建的手写数字5图像(大小为20×20)进行测试
test= cv2.imread('d5.jpg',0) #打开图像
test_data=test.reshape(1,400).astype(np.float32) #转换为测试数据
ret,result,neighbours,dist = knn.findNearest(test_data,k=3) #执行测试
print(result.ravel()) #输出测试结果
print(neighbours.ravel())
将对手写数字9拍摄所得图像的大小转换为20×20进行测试
img2=cv2.imread('d9.jpg',0)
ret,img2=cv2.threshold(img2,150,255,cv2.THRESH_BINARY_INV) #反二值化阈值处理
test_data=img2.reshape(1,400).astype(np.float32) #转换为测试数据
ret,result,neighbours,dist = knn.findNearest(test_data,k=3) #执行测试
print(result.ravel()) #输出测试结果
print(neighbours.ravel())