2.4.2 莺尾花数据集--kNN分类
Step1: 库函数导入
1importnumpyasnp# 加载莺尾花数据集fromsklearnimportdatasets# 导入KNN分类器fromsklearn.neighborsimportKNeighborsClassifierfromsklearn.model_selectionimporttrain_test_split
Step2: 数据导入&分析
1
# 导入莺尾花数据集iris=datasets.load_iris() X=iris.datay=iris.target# 得到训练集合和验证集合, 8: 2X_train, X_test, y_train, y_test=train_test_split(X, y, test_size=0.2)
Step3: 模型训练
这里我们设置参数k(n_neighbors)=5, 使用欧式距离(metric=minkowski & p=2)
1# 训练模型clf=KNeighborsClassifier(n_neighbors=5, p=2, metric="minkowski") clf.fit(X_train, y_train)
[3]:
KNeighborsClassifier()
KNeighborsClassifier
n_neighbors : int,optional(default = 5)
默认情况下kneighbors查询使用的邻居数。就是k-NN的k的值,选取最近的k个点。
weights : str或callable,可选(默认=‘uniform’)
默认是uniform,参数可以是uniform、distance,也可以是用户自己定义的函数。uniform是均等的权重,就说所有的邻近点的权重都是相等的。distance是不均等的权重,距离近的点比距离远的点的影响大。用户自定义的函数,接收距离的数组,返回一组维数相同的权重。
algorithm : {‘auto’,‘ball_tree’,‘kd_tree’,‘brute’},可选
快速k近邻搜索算法,默认参数为auto,可以理解为算法自己决定合适的搜索算法。除此之外,用户也可以自己指定搜索算法ball_tree、kd_tree、brute方法进行搜索,brute是蛮力搜索,也就是线性扫描,当训练集很大时,计算非常耗时。kd_tree,构造kd树存储数据以便对其进行快速检索的树形数据结构,kd树也就是数据结构中的二叉树。以中值切分构造的树,每个结点是一个超矩形,在维数小于20时效率高。ball tree是为了克服kd树高纬失效而发明的,其构造过程是以质心C和半径r分割样本空间,每个节点是一个超球体。
leaf_size : int,optional(默认值= 30)
默认是30,这个是构造的kd树和ball树的大小。这个值的设置会影响树构建的速度和搜索速度,同样也影响着存储树所需的内存大小。需要根据问题的性质选择最优的大小。
p : 整数,可选(默认= 2)
距离度量公式。在上小结,我们使用欧氏距离公式进行距离度量。除此之外,还有其他的度量方法,例如曼哈顿距离。这个参数默认为2,也就是默认使用欧式距离公式进行距离度量。也可以设置为1,使用曼哈顿距离公式进行距离度量。
metric : 字符串或可调用,默认为’minkowski’
用于距离度量,默认度量是minkowski,也就是p=2的欧氏距离(欧几里德度量)。
metric_params : dict,optional(默认=None)
距离公式的其他关键参数,这个可以不管,使用默认的None即可。
n_jobs : int或None,可选(默认=None)
并行处理设置。默认为1,临近点搜索并行工作数。如果为-1,那么CPU的所有cores都用于并行工作。
Step4:模型预测&可视化
1# 预测X_pred=clf.predict(X_test) acc=sum(X_pred==y_test) /X_pred.shape[0] print("预测的准确率ACC: %.3f"%acc)
预测的准确率ACC: 0.933
我们用表格来看一下KNN的训练和预测过程。这里用表格进行可视化:
- 训练数据[表格对应list]
feat_1 |
feat_2 |
feat_3 |
feat_4 |
label |
5.1 |
3.5 |
1.4 |
0.2 |
0 |
4.9 |
3. |
1.4 |
0.2 |
0 |
4.7 |
3.2 |
1.3 |
0.2 |
0 |
4.6 |
3.1 |
1.5 |
0.2 |
0 |
6.4 |
3.2 |
4.5 |
1.5 |
1 |
6.9 |
3.1 |
4.9 |
1.5 |
1 |
5.5 |
2.3 |
4. |
1.3 |
1 |
6.5 |
2.8 |
4.6 |
1.5 |
1 |
5.8 |
2.7 |
5.1 |
1.9 |
2 |
7.1 |
3. |
5.9 |
2.1 |
2 |
6.3 |
2.9 |
5.6 |
1.8 |
2 |
6.5 |
3. |
5.8 |
2.2 |
2 |
- knn.fit(X, y)的过程可以简单认为是表格存储
feat_1 |
feat_2 |
feat_3 |
feat_4 |
label |
5.1 |
3.5 |
1.4 |
0.2 |
0 |
4.9 |
3. |
1.4 |
0.2 |
0 |
4.7 |
3.2 |
1.3 |
0.2 |
0 |
4.6 |
3.1 |
1.5 |
0.2 |
0 |
6.4 |
3.2 |
4.5 |
1.5 |
1 |
6.9 |
3.1 |
4.9 |
1.5 |
1 |
5.5 |
2.3 |
4. |
1.3 |
1 |
6.5 |
2.8 |
4.6 |
1.5 |
1 |
5.8 |
2.7 |
5.1 |
1.9 |
2 |
7.1 |
3. |
5.9 |
2.1 |
2 |
6.3 |
2.9 |
5.6 |
1.8 |
2 |
6.5 |
3. |
5.8 |
2.2 |
2 |
- knn.predict(x)预测过程会计算x和所有训练数据的距离 这里我们使用欧式距离进行计算, 预测过程如下
step1: 计算x和所有训练数据的距离
feat_1 |
feat_2 |
feat_3 |
feat_4 |
距离 |
label |
5.1 |
3.5 |
1.4 |
0.2 |
0.14142136 |
0 |
4.9 |
3. |
1.4 |
0.2 |
0.60827625 |
0 |
4.7 |
3.2 |
1.3 |
0.2 |
0.50990195 |
0 |
4.6 |
3.1 |
1.5 |
0.2 |
0.64807407 |
0 |
6.4 |
3.2 |
4.5 |
1.5 |
3.66333182 |
1 |
6.9 |
3.1 |
4.9 |
1.5 |
4.21900462 |
1 |
5.5 |
2.3 |
4. |
1.3 |
3.14801525 |
1 |
6.5 |
2.8 |
4.6 |
1.5 |
3.84967531 |
1 |
5.8 |
2.7 |
5.1 |
1.9 |
4.24617475 |
2 |
7.1 |
3. |
5.9 |
2.1 |
5.35070089 |
2 |
6.3 |
2.9 |
5.6 |
1.8 |
4.73075047 |
2 |
6.5 |
3. |
5.8 |
2.2 |
5.09607692 |
2 |
step2: 根据距离进行编号排序
距离升序编号 |
feat_1 |
feat_2 |
feat_3 |
feat_4 |
距离 |
label |
1 |
5.1 |
3.5 |
1.4 |
0.2 |
0.14142136 |
0 |
3 |
4.9 |
3. |
1.4 |
0.2 |
0.60827625 |
0 |
2 |
4.7 |
3.2 |
1.3 |
0.2 |
0.50990195 |
0 |
4 |
4.6 |
3.1 |
1.5 |
0.2 |
0.64807407 |
0 |
6 |
6.4 |
3.2 |
4.5 |
1.5 |
3.66333182 |
1 |
8 |
6.9 |
3.1 |
4.9 |
1.5 |
4.21900462 |
1 |
5 |
5.5 |
2.3 |
4. |
1.3 |
3.14801525 |
1 |
7 |
6.5 |
2.8 |
4.6 |
1.5 |
3.84967531 |
1 |
9 |
5.8 |
2.7 |
5.1 |
1.9 |
4.24617475 |
2 |
12 |
7.1 |
3. |
5.9 |
2.1 |
5.35070089 |
2 |
10 |
6.3 |
2.9 |
5.6 |
1.8 |
4.73075047 |
2 |
11 |
6.5 |
3. |
5.8 |
2.2 |
5.09607692 |
2 |
step3: 我们设置k=5,选择距离最近的k个样本进行投票
距离升序编号 |
feat_1 |
feat_2 |
feat_3 |
feat_4 |
距离 |
label |
1 |
5.1 |
3.5 |
1.4 |
0.2 |
0.14142136 |
0 |
3 |
4.9 |
3. |
1.4 |
0.2 |
0.60827625 |
0 |
2 |
4.7 |
3.2 |
1.3 |
0.2 |
0.50990195 |
0 |
4 |
4.6 |
3.1 |
1.5 |
0.2 |
0.64807407 |
0 |
6 |
6.4 |
3.2 |
4.5 |
1.5 |
3.66333182 |
1 |
8 |
6.9 |
3.1 |
4.9 |
1.5 |
4.21900462 |
1 |
5 |
5.5 |
2.3 |
4. |
1.3 |
3.14801525 |
1 |
7 |
6.5 |
2.8 |
4.6 |
1.5 |
3.84967531 |
1 |
9 |
5.8 |
2.7 |
5.1 |
1.9 |
4.24617475 |
2 |
12 |
7.1 |
3. |
5.9 |
2.1 |
5.35070089 |
2 |
10 |
6.3 |
2.9 |
5.6 |
1.8 |
4.73075047 |
2 |
11 |
6.5 |
3. |
5.8 |
2.2 |
5.09607692 |
2 |
step4: k近邻的label进行投票
nn_labels = [0, 0, 0, 0, 1] --> 得到最后的结果0。