In [34]: from sklearn.model_selection import GridSearchCV
knn=KNeighborsClassifier()
k_range=range(1,30)
param_grid=dict(n_neighbors=k_range)
grid=GridSearchCV(knn,param_grid,cv=5,scoring='accuracy')
grid.fit(my_data[['sepal_length','sepal_width']],my_class)
Out[34]: GridSearchCV(cv=5, error_score='raise-deprecating',
estimator=KNeighborsClassifier(algorithm='auto', leaf_size=30,
metric='minkowski',
metric_params=None, n_jobs=None,
n_neighbors=5, p=2,
weights='uniform'),
iid='warn', n_jobs=None, param_grid={'n_neighbors': range(1, 30)},
pre_dispatch='2*n_jobs', refit=True, return_train_score=False,
scoring='accuracy', verbose=0)
In [42]: pd.DataFrame(grid.cv_results_).head(29)
Out[42]: mean_fit_time std_fit_time mean_score_time std_score_time \
0 0.001196 4.007471e-04 0.001199 3.980875e-04
1 0.001396 4.974370e-04 0.001589 8.019150e-04
2 0.001600 8.093282e-04 0.000797 3.986157e-04
3 0.001784 3.930052e-04 0.001216 4.061510e-04
4 0.001203 4.268555e-04 0.001196 3.978265e-04
5 0.000999 2.887130e-06 0.000995 3.635908e-05
6 0.000985 2.348654e-05 0.000598 4.881907e-04
7 0.000997 2.081404e-05 0.001004 1.305838e-05
8 0.000997 2.431402e-07 0.000998 2.138815e-05
9 0.000798 3.992108e-04 0.000798 3.987560e-04
10 0.000990 1.368896e-05 0.001005 1.342667e-05
11 0.000798 3.991871e-04 0.000598 4.884969e-04
12 0.000998 1.723224e-06 0.000798 3.990436e-04
13 0.000997 1.144409e-06 0.000997 2.126630e-05
14 0.000997 6.217196e-07 0.000605 4.942987e-04
15 0.000998 9.608003e-07 0.000990 1.295875e-05
16 0.000798 3.987558e-04 0.000805 4.024717e-04
17 0.000990 1.435953e-05 0.000799 3.992572e-04
18 0.000805 4.028115e-04 0.000990 1.312126e-05
19 0.000997 2.048736e-06 0.000799 3.994484e-04
20 0.000998 2.183664e-05 0.000997 8.064048e-07
21 0.000798 3.988554e-04 0.000798 3.990891e-04
22 0.000997 5.091228e-07 0.000997 2.112017e-05
23 0.000798 3.988763e-04 0.001010 2.363564e-05
24 0.000997 3.021809e-06 0.000997 1.843085e-06
25 0.000982 4.103033e-05 0.001012 2.276996e-05
26 0.001005 1.443124e-05 0.000985 2.403921e-05
27 0.000997 3.173744e-06 0.000997 1.256174e-06
28 0.000791 3.956460e-04 0.001012 1.680579e-05
param_n_neighbors params split0_test_score \
0 1 {'n_neighbors': 1} 0.733333
1 2 {'n_neighbors': 2} 0.700000
2 3 {'n_neighbors': 3} 0.666667
3 4 {'n_neighbors': 4} 0.666667
4 5 {'n_neighbors': 5} 0.700000
5 6 {'n_neighbors': 6} 0.733333
6 7 {'n_neighbors': 7} 0.733333
7 8 {'n_neighbors': 8} 0.733333
8 9 {'n_neighbors': 9} 0.733333
9 10 {'n_neighbors': 10} 0.700000
10 11 {'n_neighbors': 11} 0.733333
11 12 {'n_neighbors': 12} 0.766667
12 13 {'n_neighbors': 13} 0.733333
13 14 {'n_neighbors': 14} 0.733333
14 15 {'n_neighbors': 15} 0.733333
15 16 {'n_neighbors': 16} 0.733333
16 17 {'n_neighbors': 17} 0.733333
17 18 {'n_neighbors': 18} 0.733333
18 19 {'n_neighbors': 19} 0.733333
19 20 {'n_neighbors': 20} 0.733333
20 21 {'n_neighbors': 21} 0.733333
21 22 {'n_neighbors': 22} 0.733333
22 23 {'n_neighbors': 23} 0.733333
23 24 {'n_neighbors': 24} 0.700000
24 25 {'n_neighbors': 25} 0.700000
25 26 {'n_neighbors': 26} 0.700000
26 27 {'n_neighbors': 27} 0.733333
27 28 {'n_neighbors': 28} 0.700000
28 29 {'n_neighbors': 29} 0.733333
split1_test_score split2_test_score split3_test_score \
0 0.733333 0.666667 0.833333
1 0.733333 0.666667 0.766667
2 0.800000 0.633333 0.866667
3 0.800000 0.733333 0.800000
4 0.766667 0.733333 0.866667
5 0.866667 0.833333 0.900000
6 0.833333 0.800000 0.866667
7 0.800000 0.766667 0.866667
8 0.766667 0.766667 0.866667
9 0.766667 0.800000 0.833333
10 0.766667 0.733333 0.833333
11 0.800000 0.700000 0.833333
12 0.766667 0.733333 0.833333
13 0.733333 0.700000 0.866667
14 0.800000 0.733333 0.866667
15 0.833333 0.766667 0.900000
16 0.800000 0.766667 0.933333
17 0.800000 0.766667 0.866667
18 0.766667 0.766667 0.866667
19 0.833333 0.766667 0.800000
20 0.800000 0.766667 0.866667
21 0.833333 0.833333 0.833333
22 0.833333 0.800000 0.866667
23 0.800000 0.833333 0.833333
24 0.833333 0.800000 0.833333
25 0.833333 0.800000 0.833333
26 0.833333 0.733333 0.866667
27 0.833333 0.733333 0.833333
28 0.800000 0.733333 0.866667
split4_test_score mean_test_score std_test_score rank_test_score
0 0.666667 0.726667 0.061101 26
1 0.633333 0.700000 0.047140 29
2 0.666667 0.726667 0.090431 26
3 0.600000 0.720000 0.077746 28
4 0.766667 0.766667 0.055777 23
5 0.733333 0.813333 0.068638 2
6 0.733333 0.793333 0.053333 8
7 0.700000 0.773333 0.057349 19
8 0.733333 0.773333 0.048990 19
9 0.733333 0.766667 0.047140 23
10 0.800000 0.773333 0.038873 19
11 0.766667 0.773333 0.044222 19
12 0.833333 0.780000 0.045216 18
13 0.733333 0.753333 0.058119 25
14 0.800000 0.786667 0.049889 14
15 0.766667 0.800000 0.059628 5
16 0.866667 0.820000 0.071802 1
17 0.800000 0.793333 0.044222 8
18 0.833333 0.793333 0.048990 8
19 0.800000 0.786667 0.033993 14
20 0.766667 0.786667 0.045216 14
21 0.733333 0.793333 0.048990 8
22 0.833333 0.813333 0.045216 2
23 0.833333 0.800000 0.051640 5
24 0.866667 0.806667 0.057349 4
25 0.800000 0.793333 0.048990 8
26 0.833333 0.800000 0.055777 5
27 0.833333 0.786667 0.058119 14
28 0.833333 0.793333 0.053333 8
In [35]: print(grid.cv_results_)
{'mean_fit_time': array([0.001196 , 0.00139551, 0.00159965, 0.00178361, 0.00120277,
0.00099864, 0.0009851 , 0.0009975 , 0.0009974 , 0.00079842,
0.00098987, 0.00079837, 0.00099764, 0.00099716, 0.0009973 ,
0.00099835, 0.00079751, 0.00099001, 0.00080519, 0.00099697,
0.00099769, 0.0007977 , 0.00099735, 0.00079775, 0.00099697,
0.00098243, 0.00100498, 0.0009973 , 0.00079083]), 'std_fit_time':
array([4.00747083e-04, 4.97437049e-04, 8.09328245e-04, 3.93005211e-04,
4.26855489e-04, 2.88712988e-06, 2.34865415e-05, 2.08140375e-05,
2.43140197e-07, 3.99210774e-04, 1.36889642e-05, 3.99187143e-04,
1.72322378e-06, 1.14440918e-06, 6.21719590e-07, 9.60800251e-07,
3.98755797e-04, 1.43595296e-05, 4.02811513e-04, 2.04873572e-06,
2.18366430e-05, 3.98855440e-04, 5.09122765e-07, 3.98876264e-04,
3.02180853e-06, 4.10303343e-05, 1.44312385e-05, 3.17374445e-06,
3.95645988e-04]), 'mean_score_time': array([0.00119882, 0.00158925, 0.00079722,
0.0012157 , 0.00119586,
0.00099535, 0.00059791, 0.00100389, 0.00099802, 0.00079751,
0.0010046 , 0.00059767, 0.00079808, 0.00099688, 0.0006052 ,
0.00098953, 0.00080452, 0.00079851, 0.00098991, 0.00079889,
0.00099697, 0.00079818, 0.0009973 , 0.00100989, 0.00099688,
0.00101218, 0.000985 , 0.00099711, 0.00101242]), 'std_score_time':
array([3.98087535e-04, 8.01915004e-04, 3.98615655e-04, 4.06150958e-04,
3.97826513e-04, 3.63590756e-05, 4.88190680e-04, 1.30583752e-05,
2.13881532e-05, 3.98755968e-04, 1.34266709e-05, 4.88496859e-04,
3.99043637e-04, 2.12662958e-05, 4.94298713e-04, 1.29587461e-05,
4.02471736e-04, 3.99257191e-04, 1.31212555e-05, 3.99448380e-04,
8.06404806e-07, 3.99089127e-04, 2.11201723e-05, 2.36356400e-05,
1.84308511e-06, 2.27699604e-05, 2.40392146e-05, 1.25617408e-06,
1.68057872e-05]), 'param_n_neighbors': masked_array(data=[1, 2, 3, 4, 5,
6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False],
fill_value='?',
dtype=object), 'params': [{'n_neighbors': 1}, {'n_neighbors': 2},
{'n_neighbors': 3}, {'n_neighbors': 4}, {'n_neighbors': 5},
{'n_neighbors': 6}, {'n_neighbors': 7}, {'n_neighbors': 8},
{'n_neighbors': 9}, {'n_neighbors': 10}, {'n_neighbors': 11},
{'n_neighbors': 12}, {'n_neighbors': 13}, {'n_neighbors': 14},
{'n_neighbors': 15}, {'n_neighbors': 16}, {'n_neighbors': 17},
{'n_neighbors': 18}, {'n_neighbors': 19}, {'n_neighbors': 20},
{'n_neighbors': 21}, {'n_neighbors': 22}, {'n_neighbors': 23},
{'n_neighbors': 24}, {'n_neighbors': 25}, {'n_neighbors': 26},
{'n_neighbors': 27}, {'n_neighbors': 28}, {'n_neighbors': 29}],
'split0_test_score':
array([0.73333333, 0.7 , 0.66666667, 0.66666667, 0.7 ,
0.73333333, 0.73333333, 0.73333333, 0.73333333, 0.7 ,
0.73333333, 0.76666667, 0.73333333, 0.73333333, 0.73333333,
0.73333333, 0.73333333, 0.73333333, 0.73333333, 0.73333333,
0.73333333, 0.73333333, 0.73333333, 0.7 , 0.7 ,
0.7 , 0.73333333, 0.7 , 0.73333333]), 'split1_test_score':
array([0.73333333, 0.73333333, 0.8 , 0.8 , 0.76666667,
0.86666667, 0.83333333, 0.8 , 0.76666667, 0.76666667,
0.76666667, 0.8 , 0.76666667, 0.73333333, 0.8 ,
0.83333333, 0.8 , 0.8 , 0.76666667, 0.83333333,
0.8 , 0.83333333, 0.83333333, 0.8 , 0.83333333,
0.83333333, 0.83333333, 0.83333333, 0.8 ]), 'split2_test_score':
array([0.66666667, 0.66666667, 0.63333333, 0.73333333, 0.73333333,
0.83333333, 0.8 , 0.76666667, 0.76666667, 0.8 ,
0.73333333, 0.7 , 0.73333333, 0.7 , 0.73333333,
0.76666667, 0.76666667, 0.76666667, 0.76666667, 0.76666667,
0.76666667, 0.83333333, 0.8 , 0.83333333, 0.8 ,
0.8 , 0.73333333, 0.73333333, 0.73333333]), 'split3_test_score':
array([0.83333333, 0.76666667, 0.86666667, 0.8 , 0.86666667,
0.9 , 0.86666667, 0.86666667, 0.86666667, 0.83333333,
0.83333333, 0.83333333, 0.83333333, 0.86666667, 0.86666667,
0.9 , 0.93333333, 0.86666667, 0.86666667, 0.8 ,
0.86666667, 0.83333333, 0.86666667, 0.83333333, 0.83333333,
0.83333333, 0.86666667, 0.83333333, 0.86666667]), 'split4_test_score':
array([0.66666667, 0.63333333, 0.66666667, 0.6 , 0.76666667,
0.73333333, 0.73333333, 0.7 , 0.73333333, 0.73333333,
0.8 , 0.76666667, 0.83333333, 0.73333333, 0.8 ,
0.76666667, 0.86666667, 0.8 , 0.83333333, 0.8 ,
0.76666667, 0.73333333, 0.83333333, 0.83333333, 0.86666667,
0.8 , 0.83333333, 0.83333333, 0.83333333]), 'mean_test_score':
array([0.72666667, 0.7 , 0.72666667, 0.72 , 0.76666667,
0.81333333, 0.79333333, 0.77333333, 0.77333333, 0.76666667,
0.77333333, 0.77333333, 0.78 , 0.75333333, 0.78666667,
0.8 , 0.82 , 0.79333333, 0.79333333, 0.78666667,
0.78666667, 0.79333333, 0.81333333, 0.8 , 0.80666667,
0.79333333, 0.8 , 0.78666667, 0.79333333]), 'std_test_score':
array([0.06110101, 0.04714045, 0.09043107, 0.07774603, 0.05577734,
0.06863753, 0.05333333, 0.05734884, 0.04898979, 0.04714045,
0.03887301, 0.04422166, 0.04521553, 0.05811865, 0.04988877,
0.05962848, 0.0718022 , 0.04422166, 0.04898979, 0.03399346,
0.04521553, 0.04898979, 0.04521553, 0.05163978, 0.05734884,
0.04898979, 0.05577734, 0.05811865, 0.05333333]), 'rank_test_score':
array([26, 29, 26, 28, 23, 2, 8, 19, 19, 23, 19, 19, 18, 25, 14, 5, 1,
8, 8, 14, 14, 8, 2, 5, 4, 8, 5, 14, 8])}
In [39]: grid_mean_scores= grid.cv_results_['mean_test_score']
print(grid_mean_scores,'\n')
plt.figure()
plt.xlabel('Tuning Parameter: N nearest neighbors')
plt.ylabel('Classification Accuracy')
plt.plot(k_range,grid_mean_scores)
print('最高得分是近邻值取 k =',grid.best_params_['n_neighbors'],'时的得分'
,grid.best_score_)
plt.plot(grid.best_params_['n_neighbors'],grid.best_score_,'ro',
markersize=12,markeredgewidth=1.5,
markerfacecolor='None',markeredgecolor='r')
[0.72666667 0.7 0.72666667 0.72 0.76666667 0.81333333
0.79333333 0.77333333 0.77333333 0.76666667 0.77333333 0.77333333
0.78 0.75333333 0.78666667 0.8 0.82 0.79333333
0.79333333 0.78666667 0.78666667 0.79333333 0.81333333 0.8
0.80666667 0.79333333 0.8 0.78666667 0.79333333]
最高得分是近邻值取 k = 17 时的得分 0.82
Out[39]: [<matplotlib.lines.Line2D at 0x2042e7b5518>]