K Nearest Neighbors
这个算法首先贮藏所有的训练样本,然后通过分析(包括选举,计算加权和等方式)一个新样本周围K个最近邻以给出该样本的相应值。这种方法有时候被称作“基于样本的学习”,即为了预测,我们对于给定的输入搜索最近的已知其相应的特征向量。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
|
class
CvKNearest :
public
CvStatModel
//继承自ML库中的统计模型基类
{
public
:
CvKNearest();
//无参构造函数
virtual
~CvKNearest();
//虚函数定义
CvKNearest(
const
CvMat* _train_data,
const
CvMat* _responses,
const
CvMat* _sample_idx=0,
bool
_is_regression=
false
,
int
max_k=32 );
//有参构造函数
virtual
bool
train(
const
CvMat* _train_data,
const
CvMat* _responses,
const
CvMat* _sample_idx=0,
bool
is_regression=
false
,
int
_max_k=32,
bool
_update_base=
false
);
virtual
float
find_nearest(
const
CvMat* _samples,
int
k, CvMat* results,
const
float
** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 )
const
;
virtual
void
clear();
int
get_max_k()
const
;
int
get_var_count()
const
;
int
get_sample_count()
const
;
bool
is_regression()
const
;
protected
:
...
};
|
CvKNearest::train
训练KNN模型
bool
CvKNearest::train(
const
CvMat* _train_data,
const
CvMat* _responses,
const
CvMat* _sample_idx=0,
bool
is_regression=
false
,
int
_max_k=32,
bool
_update_base=
false
);
|
这个类的方法训练K近邻模型。它遵循一个一般训练方法约定的限制:只支持CV_ROW_SAMPLE数据格式,输入向量必须都是有序的,而输出可以 是 无序的(当is_regression=false),可以是有序的(is_regression=true)。并且变量子集和省略度量是不被支持的。
参数_max_k 指定了最大邻居的个数,它将被传给方法find_nearest。 参数 _update_base 指定模型是由原来的数据训练(_update_base=false),还是被新训练数据更新后再训练(_update_base=true)。在后一种情况下_max_k 不能大于原值, 否则它会被忽略.
CvKNearest::find_nearest
寻找输入向量的最近邻
float
CvKNearest::find_nearest(
const
CvMat* _samples,
int
k, CvMat* results=0,
const
float
** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 )
const
;
|
对每个输入向量(表示为matrix_sample的每一行),该方法找到k(k≤get_max_k() )个最近邻。在回归中,预测结果将是指定向量的近邻的响应的均值。在分类中,类别将由投票决定。
对传统分类和回归预测来说,该方法可以有选择的返回近邻向量本身的指针(neighbors, array of k*_samples->rows pointers),它们相对应的输出值(neighbor_responses, a vector of k*_samples->rows elements) ,和输入向量与近邻之间的距离(dist, also a vector of k*_samples->rows elements)。
对每个输入向量来说,近邻将按照它们到该向量的距离排序。
对单个输入向量,所有的输出矩阵是可选的,而且预测值将由该方法返回。
例程:使用kNN进行2维样本集的分类,样本集的分布为混合高斯分布
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
|
#include "ml.h"
#include "highgui.h"
int
main(
int
argc,
char
** argv )
{
const
int
K = 10;
int
i, j, k, accuracy;
float
response;
int
train_sample_count = 100;
CvRNG rng_state = cvRNG(-1);
CvMat* trainData = cvCreateMat( train_sample_count, 2, CV_32FC1 );
CvMat* trainClasses = cvCreateMat( train_sample_count, 1, CV_32FC1 );
IplImage* img = cvCreateImage( cvSize( 500, 500 ), 8, 3 );
float
_sample[2];
CvMat sample = cvMat( 1, 2, CV_32FC1, _sample );
cvZero( img );
CvMat trainData1, trainData2, trainClasses1, trainClasses2;
// form the training samples
cvGetRows( trainData, &trainData1, 0, train_sample_count/2 );
cvRandArr( &rng_state, &trainData1, CV_RAND_NORMAL, cvScalar(200,200), cvScalar(50,50) );
cvGetRows( trainData, &trainData2, train_sample_count/2, train_sample_count );
cvRandArr( &rng_state, &trainData2, CV_RAND_NORMAL, cvScalar(300,300), cvScalar(50,50) );
cvGetRows( trainClasses, &trainClasses1, 0, train_sample_count/2 );
cvSet( &trainClasses1, cvScalar(1) );
cvGetRows( trainClasses, &trainClasses2, train_sample_count/2, train_sample_count );
cvSet( &trainClasses2, cvScalar(2) );
// learn classifier
CvKNearest knn( trainData, trainClasses, 0,
false
, K );
CvMat* nearests = cvCreateMat( 1, K, CV_32FC1);
for
( i = 0; i < img->height; i++ )
{
for
( j = 0; j < img->width; j++ )
{
sample.data.fl[0] = (
float
)j;
sample.data.fl[1] = (
float
)i;
// estimates the response and get the neighbors' labels
response = knn.find_nearest(&sample,K,0,0,nearests,0);
// compute the number of neighbors representing the majority
for
( k = 0, accuracy = 0; k < K; k++ )
{
if
( nearests->data.fl[k] == response)
accuracy++;
}
// highlight the pixel depending on the accuracy (or confidence)
cvSet2D( img, i, j, response == 1 ?
(accuracy > 5 ? CV_RGB(180,0,0) : CV_RGB(180,120,0)) :
(accuracy > 5 ? CV_RGB(0,180,0) : CV_RGB(120,120,0)) );
}
}
// display the original training samples
for
( i = 0; i < train_sample_count/2; i++ )
{
CvPoint pt;
pt.x = cvRound(trainData1.data.fl[i*2]);
pt.y = cvRound(trainData1.data.fl[i*2+1]);
cvCircle( img, pt, 2, CV_RGB(255,0,0), CV_FILLED );
pt.x = cvRound(trainData2.data.fl[i*2]);
pt.y = cvRound(trainData2.data.fl[i*2+1]);
cvCircle( img, pt, 2, CV_RGB(0,255,0), CV_FILLED );
}
cvNamedWindow(
"classifier result"
, 1 );
cvShowImage(
"classifier result"
, img );
cvWaitKey(0);
cvReleaseMat( &trainClasses );
cvReleaseMat( &trainData );
return
0;
}
|
结果:
本文转自编程小翁博客园博客,原文链接:http://www.cnblogs.com/wengzilin/archive/2013/04/05/3001778.html,如需转载请自行联系原作者