Try scikit-learn
sklearn.neighbors.KNeighborsClassifier
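- n_neighbors: number of neighbors, i.e. the value of k; default 5
- p: power parameter of the Minkowski distance; default 2 (Euclidean distance)
- algorithm: nearest-neighbor search algorithm, one of {'auto', 'ball_tree', 'kd_tree', 'brute'}
- weights: how the neighbors are weighted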
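The p parameter is the exponent of the Minkowski distance, which is scikit-learn's default metric. For two n-dimensional points x and y:

```latex
d_p(\mathbf{x}, \mathbf{y}) = \left( \sum_{i=1}^{n} \lvert x_i - y_i \rvert^{p} \right)^{1/p}
```

p = 1 gives the Manhattan distance and p = 2 the Euclidean distance. For example, with x = (0, 0) and y = (3, 4): d_1 = 3 + 4 = 7 and d_2 = sqrt(9 + 16) = 5.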
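- n_neighbors : int, optional (default = 5)
  Number of neighbors used by default in kneighbors queries, i.e. the k of k-NN: the k nearest points are selected.
- weights : str or callable, optional (default = 'uniform')
  Can be 'uniform', 'distance', or a user-defined function. 'uniform' gives equal weights: every neighbor counts the same. 'distance' weights each neighbor by the inverse of its distance, so closer points have more influence than distant ones. A user-defined function receives an array of distances and returns an array of weights of the same shape.
- algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, optional
  Algorithm used for the fast k-nearest-neighbor search. The default 'auto' lets scikit-learn choose the most appropriate algorithm based on the data passed to fit. You can also specify 'ball_tree', 'kd_tree', or 'brute' yourself. 'brute' is brute-force search, i.e. a linear scan, which becomes very time-consuming when the training set is large. 'kd_tree' builds a kd-tree, a binary tree data structure that stores the data for fast retrieval: it is built by splitting at the median, each node corresponds to a hyperrectangle, and it is efficient when the dimensionality is below roughly 20. The ball tree was designed to overcome the kd-tree's loss of efficiency in high dimensions: it partitions the sample space using a centroid C and a radius r, so each node is a hypersphere.
- leaf_size : int, optional (default = 30)
  Leaf size passed to the kd-tree or ball tree, i.e. the number of points at which the tree stops splitting and switches to brute force inside a leaf. It affects the speed of tree construction and of queries, as well as the memory needed to store the tree; the best value depends on the nature of the problem.
- p : int, optional (default = 2)
  Power parameter of the Minkowski distance. In the previous section we measured distances with the Euclidean formula; other measures exist, such as the Manhattan distance. The default p = 2 uses the Euclidean distance; p = 1 uses the Manhattan distance.
- metric : str or callable, default 'minkowski'
  Distance metric. The default is 'minkowski', which with the default p = 2 is the Euclidean distance.
- metric_params : dict, optional (default = None)
  Additional keyword arguments for the metric function; this can usually be left at the default None.
- n_jobs : int or None, optional (default = None)
  Number of parallel jobs for the neighbor search. None means 1 (unless running inside a joblib.parallel_backend context); -1 uses all CPU cores.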
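The parameters above can be tried out side by side. This is a minimal sketch that assumes the iris dataset bundled with scikit-learn as stand-in data; the Gaussian weight function `gaussian_weights` is a hypothetical example of a user-defined weights callable, not something the text prescribes. It checks that the three search algorithms return the same neighbor distances (they differ only in how the search is done) and shows p=1, the weights options, and n_jobs.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
# Shift the query points off the training grid so no query coincides with a training point
query = X[:5] + 0.05

# algorithm only changes HOW neighbors are searched; the neighbor distances should agree
dists = {}
for algo in ["brute", "kd_tree", "ball_tree"]:
    clf = KNeighborsClassifier(n_neighbors=5, algorithm=algo, leaf_size=30).fit(X, y)
    d, _ = clf.kneighbors(query)  # shape (5, 5), each row sorted by ascending distance
    dists[algo] = d

# p=1 switches the Minkowski distance to Manhattan (default p=2 is Euclidean)
manhattan = KNeighborsClassifier(n_neighbors=5, p=1).fit(X, y)

# Hypothetical user-defined weights: maps an array of distances to weights of the same shape
def gaussian_weights(distances):
    return np.exp(-distances ** 2)

uniform = KNeighborsClassifier(n_neighbors=5, weights="uniform").fit(X, y)
by_distance = KNeighborsClassifier(n_neighbors=5, weights="distance").fit(X, y)
custom = KNeighborsClassifier(n_neighbors=5, weights=gaussian_weights).fit(X, y)

# n_jobs=-1 runs the neighbor search on all CPU cores
parallel = KNeighborsClassifier(n_neighbors=5, n_jobs=-1).fit(X, y)

for name, model in [("uniform", uniform), ("distance", by_distance),
                    ("gaussian", custom), ("manhattan", manhattan),
                    ("parallel", parallel)]:
    print(name, model.predict(query))
```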
# 1 Import the module
from sklearn.neighbors import KNeighborsClassifier
# 2 Create a KNN classifier instance
knn=KNeighborsClassifier(n_neighbors=4)
# 3 Fit the model
knn.fit(X_train,y_train)
# 4 Get the score (mean accuracy on the test set)
knn.score(X_test,y_test)
1.0
Try other numbers of neighbors
# 1 Import the module
from sklearn.neighbors import KNeighborsClassifier
# 2 Create a KNN classifier instance
knn=KNeighborsClassifier(n_neighbors=2)
# 3 Fit the model
knn.fit(X_train,y_train)
# 4 Get the score (mean accuracy on the test set)
knn.score(X_test,y_test)
1.0
# 1 Import the module
from sklearn.neighbors import KNeighborsClassifier
# 2 Create a KNN classifier instance
knn=KNeighborsClassifier(n_neighbors=6)
# 3 Fit the model
knn.fit(X_train,y_train)
# 4 Get the score (mean accuracy on the test set)
knn.score(X_test,y_test)
1.0
#5 Search for the best number of neighbors K; candidates are 1 to K-1 (1 to 10 for K=11)
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

def getBestK(X_train,y_train,K):
    best_score=0
    best_k=1
    best_model=KNeighborsClassifier(1)
    # hold out part of the training set as a validation set (one single split)
    X_train_set,X_val,y_train_set,y_val=train_test_split(X_train,y_train,random_state=0)
    for num in range(1,K):
        knn=KNeighborsClassifier(num)
        knn.fit(X_train_set,y_train_set)
        score=round(knn.score(X_val,y_val),2)
        print(score,num)
        # strict '>' keeps the smallest k among tied scores
        if score>best_score:
            best_k=num
            best_score=score
            best_model=knn
    return best_k,best_score,best_model
best_k,best_score,best_model=getBestK(X_train,y_train,11)
0.95 1
0.95 2
0.95 3
0.95 4
0.95 5
1.0 6
1.0 7
1.0 8
1.0 9
1.0 10
#5 Evaluate the selected model on the test set (an estimate of its generalization performance)
best_model.score(X_test,y_test)
1.0
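The k above was chosen on the validation set of a single split of the training set, so the choice depends partly on chance: the k with the highest validation score may not perform best on the test set. Averaging over many splits with cross-validation reduces this dependence on one particular split.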
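To see how much the choice depends on the split, the single-split search can be repeated with different random seeds. This is a sketch that assumes the iris dataset as stand-in data (the notebook's own X_train/y_train are not reproduced here); `best_k_on_one_split` is a hypothetical helper name.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)

def best_k_on_one_split(seed, k_values=range(1, 11)):
    """Pick k by validation accuracy on ONE random train/validation split."""
    X_tr, X_val, y_tr, y_val = train_test_split(X, y, random_state=seed)
    scores = {k: KNeighborsClassifier(n_neighbors=k).fit(X_tr, y_tr).score(X_val, y_val)
              for k in k_values}
    # max() returns the first key reaching the top score, i.e. the smallest k among ties
    return max(scores, key=scores.get)

chosen = [best_k_on_one_split(seed) for seed in range(10)]
print(chosen)  # the selected k may differ from one split to the next
```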
#6 Try cross-validation
from sklearn.model_selection import RepeatedKFold
rkf=RepeatedKFold(n_repeats=10,n_splits=5,random_state=42)
for i,(train_index,test_index) in enumerate(rkf.split(X_train)):
    print("train_index",train_index)
    print("test_index",test_index)
    # print("new training data:",X_train[train_index],y_train[train_index])
    # print("new validation data:",X_train[test_index],y_train[test_index])
train_index [ 1 2 3 5 6 7 8 11 13 14 15 16 17 19 20 21 22 23 24 25 26 27 29 30 31 32 33 36 37 38 39 40 41 43 44 45 46 47 48 50 51 52 53 54 55 56 57 58 59 60 62 65 66 67 68 70 71 72 73 74]
test_index [ 0 4 9 10 12 18 28 34 35 42 49 61 63 64 69]
... (48 more splits omitted; 5 folds × 10 repeats = 50 splits in total, each with 60 training and 15 validation indices)
train_index [ 0 1 2 3 4 5 6 7 8 9 10 11 13 14 15 17 18 19 20 21 22 23 24 25 29 30 31 33 35 36 37 38 39 41 42 43 44 46 47 48 50 51 52 53 54 55 56 58 61 64 65 66 67 68 69 70 71 72 73 74]
test_index [12 16 26 27 28 32 34 40 45 49 57 59 60 62 63]
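Rather than reading the printed indices, the structure of RepeatedKFold can be checked directly. This sketch assumes 75 samples, the same size as X_train above, and uses dummy features since only the number of rows matters to the splitter. Each repeat reshuffles and partitions the data into 5 folds, so every sample lands in a validation fold exactly once per repeat.

```python
import numpy as np
from collections import Counter
from sklearn.model_selection import RepeatedKFold

n_samples = 75  # same size as X_train above
rkf = RepeatedKFold(n_splits=5, n_repeats=10, random_state=42)

splits = list(rkf.split(np.zeros((n_samples, 1))))
# how many times each sample index appears in a validation fold
val_counts = Counter(int(i) for _, val_idx in splits for i in val_idx)

print(len(splits))                    # 50  (n_splits * n_repeats)
print({len(tr) for tr, _ in splits})  # {60} training indices per split
print({len(v) for _, v in splits})    # {15} validation indices per split
print(set(val_counts.values()))       # {10} each sample validated once per repeat
```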
from sklearn.model_selection import cross_validate
# knn here is still the n_neighbors=6 classifier created above
cross_validate(knn,X_train,y_train,cv=rkf,scoring="accuracy",return_estimator=True)
{'fit_time': array([0.00099969, 0. , 0.00099897, ..., 0. ]),
 'score_time': array([0.00099945, 0.00100017, 0. , ..., 0.00093102]),
 'estimator': [KNeighborsClassifier(n_neighbors=6), ..., KNeighborsClassifier(n_neighbors=6)],
 'test_score': array([1. , 1. , 1. , 1. , 0.93333333, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.93333333, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ])}
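The dictionary returned by cross_validate holds one entry per split, so it is usually summarized by the mean and standard deviation of test_score. This sketch assumes the iris dataset as stand-in data and reuses the same 5-fold × 10-repeat splitter and k=6.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import RepeatedKFold, cross_validate
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
rkf = RepeatedKFold(n_splits=5, n_repeats=10, random_state=42)
result = cross_validate(KNeighborsClassifier(n_neighbors=6), X, y,
                        cv=rkf, scoring="accuracy")

scores = result["test_score"]  # one accuracy value per split
print(f"mean={scores.mean():.3f} std={scores.std():.3f} over {len(scores)} splits")
```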
#7 Search for the best K again, this time scoring each candidate by repeated K-fold cross-validation; candidates are 1 to K-1
from sklearn.model_selection import RepeatedKFold, cross_validate
from sklearn.neighbors import KNeighborsClassifier

def getBestK(X_train,y_train,K):
    best_score=0
    best_k=1
    # 5 folds x 5 repeats = 25 train/validation splits per candidate k
    rkf=RepeatedKFold(n_repeats=5,n_splits=5,random_state=42)
    for num in range(1,K):
        knn=KNeighborsClassifier(num)
        # scoring="f1" assumes a binary target
        result=cross_validate(knn,X_train,y_train,cv=rkf,scoring="f1")
        # average F1 over all splits
        score=result["test_score"].mean()
        score=round(score,2)
        print(score,num)
        if score>best_score:
            best_k=num
            best_score=score
    return best_k,best_score
best_k,best_score=getBestK(X_train,y_train,15)
best_k,best_score
0.98 1
0.99 2
0.99 3
0.99 4
0.99 5
0.99 6
1.0 7
0.99 8
0.99 9
0.98 10
0.98 11
0.97 12
0.98 13
0.97 14
(7, 1.0)
knn=KNeighborsClassifier(best_k)
knn.fit(X_train,y_train)
knn.score(X_test,y_test)
1.0
Tune automatically: use a loop to find the optimal k.
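scikit-learn also ships this kind of loop as GridSearchCV, which tries every value in param_grid with the given cross-validation splitter and, with the default refit=True, retrains the best setting on the whole training set. This is a sketch under assumptions: the iris data, the train/test split and accuracy scoring stand in for the notebook's own data and F1 scoring.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, RepeatedKFold, train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

search = GridSearchCV(
    KNeighborsClassifier(),
    param_grid={"n_neighbors": list(range(1, 15))},  # same candidates as getBestK(..., 15)
    scoring="accuracy",
    cv=RepeatedKFold(n_splits=5, n_repeats=5, random_state=42),
)
search.fit(X_train, y_train)  # refit=True retrains the best k on all of X_train
print(search.best_params_, round(search.best_score_, 3))
print(search.score(X_test, y_test))  # test accuracy of the refitted best model
```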
Experiment: use KNN for a regression task
1 Prepare the data
import numpy as np
x1=np.linspace(-10,10,100)
x2=np.linspace(-5,15,100)
# Manually construct some data: y = w1*x1 + w2*x2
w1=5
w2=4
y=x1*w1+x2*w2
y
array([-70. , -68.18181818, -66.36363636, -64.54545455, -62.72727273, -60.90909091, -59.09090909, -57.27272727, -55.45454545, -53.63636364, -51.81818182, -50. , -48.18181818, -46.36363636, -44.54545455, -42.72727273, -40.90909091, -39.09090909, -37.27272727, -35.45454545, -33.63636364, -31.81818182, -30. , -28.18181818, -26.36363636, -24.54545455, -22.72727273, -20.90909091, -19.09090909, -17.27272727, -15.45454545, -13.63636364, -11.81818182, -10. , -8.18181818, -6.36363636, -4.54545455, -2.72727273, -0.90909091, 0.90909091, 2.72727273, 4.54545455, 6.36363636, 8.18181818, 10. , 11.81818182, 13.63636364, 15.45454545, 17.27272727, 19.09090909, 20.90909091, 22.72727273, 24.54545455, 26.36363636, 28.18181818, 30. , 31.81818182, 33.63636364, 35.45454545, 37.27272727, 39.09090909, 40.90909091, 42.72727273, 44.54545455, 46.36363636, 48.18181818, 50. , 51.81818182, 53.63636364, 55.45454545, 57.27272727, 59.09090909, 60.90909091, 62.72727273, 64.54545455, 66.36363636, 68.18181818, 70. , 71.81818182, 73.63636364, 75.45454545, 77.27272727, 79.09090909, 80.90909091, 82.72727273, 84.54545455, 86.36363636, 88.18181818, 90. , 91.81818182, 93.63636364, 95.45454545, 97.27272727, 99.09090909, 100.90909091, 102.72727273, 104.54545455, 106.36363636, 108.18181818, 110. ])
x1=x1.reshape(len(x1),1)
x2=x2.reshape(len(x2),1)
y=y.reshape(len(y),1)
import pandas as pd
data=np.hstack([x1,x2,y])
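# Add some noise to the data (Gaussian with mean 0.1 and std 1, added to every column)
np.random.seed(10)  # seed the global RNG so the noise is reproducible
data=data+np.random.normal(0.1,1,[100,3])
data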
array([[-9.80997918e+00, -4.47671228e+00, -6.86113562e+01], [-9.07863100e+00, -3.29030887e+00, -6.75412089e+01], [-8.17535392e+00, -4.85515660e+00, -6.56682184e+01], [-9.33603110e+00, -4.67304042e+00, -6.39943055e+01], [-8.31454149e+00, -3.61401814e+00, -6.15552168e+01], [-9.35462761e+00, -3.99216837e+00, -6.16450829e+01], [-7.35641032e+00, -5.10713257e+00, -5.80574405e+01], [-7.75808720e+00, -2.81374154e+00, -5.72785817e+01], [-7.85420726e+00, -3.25192460e+00, -5.58260703e+01], [-7.79785201e+00, -4.59268755e+00, -5.46208629e+01], [-9.90411101e+00, -7.55985286e-01, -5.19239440e+01], [-4.91167456e+00, -1.48242138e+00, -5.06778041e+01], [-9.25608953e+00, -1.12391146e+00, -4.80701720e+01], [-6.92987717e+00, -3.58106474e+00, -4.58459514e+01], [-7.19890084e+00, -2.10260074e+00, -4.46497119e+01], [-8.56812108e+00, -2.45314063e+00, -4.19130070e+01], [-6.97527315e+00, -3.25615055e+00, -4.15373469e+01], [-6.09201512e+00, -1.07060626e+00, -4.05034362e+01], [-5.94248008e+00, 6.42232477e-01, -3.64281226e+01], [-5.99567467e+00, -2.26531046e+00, -3.32873129e+01], [-7.56906953e+00, -6.81005515e-01, -3.42368449e+01], [-6.54272630e+00, -7.32829423e-01, -3.18556358e+01], [-4.68241322e+00, -1.55653397e+00, -2.99105801e+01], [-5.61148642e+00, -1.96269845e+00, -2.80144819e+01], [-4.64818297e+00, 2.21684956e-01, -2.56420739e+01], [-5.64237828e+00, -5.05215614e-02, -2.44150985e+01], [-4.77269716e+00, 3.12543954e-01, -2.35962190e+01], [-3.93579614e+00, 3.14368041e-01, -2.04078436e+01], [-4.67599369e+00, 1.38646098e+00, -1.95569688e+01], [-4.56613680e+00, 2.18761537e-01, -1.76443732e+01], [-4.12462083e+00, 7.81731566e-01, -1.55500903e+01], [-5.00893448e+00, 8.43167883e-01, -1.37904298e+01], [-3.32575389e+00, 8.87284515e-01, -1.16870554e+01], [-4.60962500e+00, 2.47674165e+00, -9.43497025e+00], [-2.55399230e+00, 1.60304976e+00, -7.30116575e+00], [-3.92552974e+00, 2.02861216e+00, -8.47211685e+00], [-2.85445054e+00, 1.32252697e+00, -2.27221086e+00], [-3.20383909e+00, 1.56885433e+00, 
-1.46024067e+00], [-1.87732669e+00, 1.18972183e+00, -1.68276177e+00], [-1.35842429e+00, 3.76086938e+00, 3.35135047e-01], [-7.24957523e-01, 4.37716480e+00, 1.17352349e+00], [-3.70453016e+00, 5.08438460e+00, 3.35207490e+00], [-7.97872551e-01, 2.78241431e+00, 5.09073378e+00], [-3.08232423e+00, 4.21925884e+00, 7.90719675e+00], [ 5.28844300e-01, 4.16412164e+00, 1.01885052e+01], [-2.64895900e-02, 4.04451188e+00, 1.32964325e+01], [ 7.67644414e-01, 4.38295411e+00, 1.20330676e+01], [-3.17298624e-01, 5.52193479e+00, 1.44587349e+01], [-4.05576007e-01, 6.15916945e+00, 1.77192591e+01], [ 2.58635850e-01, 4.36652636e+00, 2.08469868e+01], [-1.15875757e+00, 5.86049204e+00, 2.12312972e+01], [-7.16862753e-01, 7.60609045e+00, 2.24464377e+01], [ 1.00827677e+00, 7.13593566e+00, 2.60236434e+01], [ 8.64304920e-01, 7.70071685e+00, 2.67335947e+01], [ 3.14401551e+00, 5.74841619e+00, 2.76627520e+01], [-1.18085370e-02, 5.45967297e+00, 3.01731518e+01], [ 9.67211352e-01, 6.30044676e+00, 3.31847137e+01], [ 1.32254229e+00, 6.51216091e+00, 3.31636096e+01], [ 9.66206984e-01, 8.15352634e+00, 3.54552668e+01], [ 1.50374715e+00, 8.38063421e+00, 3.82675089e+01], [ 1.20333031e+00, 8.30155252e+00, 4.05759780e+01], [ 2.84702572e+00, 7.44997601e+00, 4.16313092e+01], [ 2.82319554e+00, 7.03396275e+00, 4.33733979e+01], [ 3.88755763e+00, 9.63373825e+00, 4.63550733e+01], [ 3.31979805e+00, 1.00825563e+01, 4.66602506e+01], [ 3.67714879e+00, 8.98817386e+00, 4.71815191e+01], [ 5.61673924e+00, 8.83321195e+00, 4.90218726e+01], [ 4.64376606e+00, 1.05003123e+01, 5.16821640e+01], [ 3.38312917e+00, 9.93985678e+00, 5.44523927e+01], [ 2.90435391e+00, 8.76211593e+00, 5.72974806e+01], [ 1.94362594e+00, 8.37086325e+00, 5.69748221e+01], [ 4.86357671e+00, 8.79920772e+00, 5.92178403e+01], [ 5.21731274e+00, 8.76064972e+00, 6.30249467e+01], [ 5.86040809e+00, 1.12868041e+01, 6.26973140e+01], [ 4.05985223e+00, 8.65847315e+00, 6.61012727e+01], [ 6.19899121e+00, 8.30649111e+00, 6.37680817e+01], [ 5.73989925e+00, 1.00161474e+01, 
6.92336558e+01], [ 5.38266361e+00, 1.03971821e+01, 7.17084241e+01], [ 7.23264561e+00, 1.20494918e+01, 7.05362027e+01], [ 6.11948179e+00, 1.19855375e+01, 7.55318286e+01], [ 8.03847795e+00, 9.79749582e+00, 7.47950707e+01], [ 8.30070319e+00, 1.07233637e+01, 7.93806649e+01], [ 7.44456666e+00, 1.11936713e+01, 7.84042566e+01], [ 6.87035796e+00, 1.23168763e+01, 8.01532295e+01], [ 6.57153443e+00, 1.12686434e+01, 8.32735790e+01], [ 8.06216701e+00, 1.26805930e+01, 8.58973008e+01], [ 8.75001919e+00, 1.36698902e+01, 8.72099703e+01], [ 7.30252179e+00, 1.34260600e+01, 8.71816534e+01], [ 1.02174549e+01, 1.12734356e+01, 9.06574864e+01], [ 9.16397441e+00, 1.35946035e+01, 9.12502949e+01], [ 7.65119402e+00, 1.26062408e+01, 9.37067133e+01], [ 7.88012441e+00, 1.20190767e+01, 9.49682650e+01], [ 8.32044954e+00, 1.32807945e+01, 9.65808990e+01], [ 8.01089317e+00, 1.64722621e+01, 9.82354518e+01], [ 9.02271142e+00, 1.33190747e+01, 1.00825525e+02], [ 8.09970303e+00, 1.46680917e+01, 1.03017581e+02], [ 1.13875348e+01, 1.46989516e+01, 1.04003935e+02], [ 1.01333057e+01, 1.33257429e+01, 1.05931984e+02], [ 9.38629399e+00, 1.39040038e+01, 1.10363757e+02], [ 1.13412247e+01, 1.61090392e+01, 1.10731822e+02]])
# Split the data into training and test sets
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test=train_test_split(data[:,:2],data[:,-1])
X_train.shape,X_test.shape,y_train.shape,y_test.shape
((75, 2), (25, 2), (75,), (25,))
2 Predict the current sample's target as the mean of the targets of its K nearest neighbors
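In formula form, with N_K(x) denoting the indices of the K training samples closest to x, the uniform-weight KNN regression prediction is the plain average of their targets:

```latex
\hat{y}(\mathbf{x}) = \frac{1}{K} \sum_{i \in N_K(\mathbf{x})} y_i
```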
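# Rewritten helper: return the mean of the K nearest neighbors' targets as the prediction for x
# K (number of neighbors) is assumed to be set earlier in the notebook
from scipy.spatial.distance import euclidean

def calcu_distance_return(x,X_train,y_train):
    KNN_x=[]  # (distance, target) pairs of the current K nearest neighbors
    # iterate over every sample in the training set
    for i in range(X_train.shape[0]):
        d=euclidean(x,X_train[i])
        if len(KNN_x)<K:
            KNN_x.append((d,y_train[i]))
        else:
            KNN_x.sort()  # ascending by distance, so KNN_x[-1] is the farthest kept neighbor
            # replace the farthest kept neighbor if the new sample is closer
            if d<KNN_x[-1][0]:
                KNN_x[-1]=(d,y_train[i])
    knn_label=[item[1] for item in KNN_x]
    return np.mean(knn_label)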
# Predict the whole test set
def predict(X_test):
    y_pred=np.zeros(X_test.shape[0])
    for i in range(X_test.shape[0]):
        y_hat_i=calcu_distance_return(X_test[i],X_train,y_train)
        y_pred[i]=y_hat_i
    return y_pred
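The same prediction can also be written without the per-sample Python loop. This is a vectorized NumPy sketch, not the notebook's method: `knn_regress` is a hypothetical name, and it builds the full query-by-training distance matrix, so memory grows with n_query × n_train. np.argpartition picks the k smallest distances per row without fully sorting them, which is enough because the mean does not depend on their order.

```python
import numpy as np

def knn_regress(X_train, y_train, X_query, k):
    """Predict each query point as the mean target of its k nearest training points."""
    X_train = np.asarray(X_train, dtype=float)
    X_query = np.asarray(X_query, dtype=float)
    y_train = np.asarray(y_train, dtype=float)
    # diff has shape (n_query, n_train, n_features) via broadcasting
    diff = X_query[:, None, :] - X_train[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))  # Euclidean distances, (n_query, n_train)
    # column indices of the k smallest distances in each row (unordered)
    nearest = np.argpartition(dist, kth=k - 1, axis=1)[:, :k]
    return y_train[nearest].mean(axis=1)

# Tiny 1-D example: training targets follow y = 10 * x
X_tr = np.array([[0.0], [1.0], [2.0], [3.0], [10.0]])
y_tr = 10 * X_tr.ravel()
print(knn_regress(X_tr, y_tr, np.array([[1.1], [9.0]]), k=2))  # → [15. 65.]
```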
# Output the predictions
y_pred= predict(X_test)
y_pred
array([-48.77391118, -61.82953142, -7.08681066, 31.79119171, 89.89605669, 49.28413251, 52.97713079, 33.48545677, 63.32131747, 98.05154212, -55.78008004, 98.04210317, 7.02443886, -19.02562562, 11.49285143, -13.67585848, 52.97713079, 21.82629113, 10.45687568, 55.14568247, -9.552268 , 94.91846026, -11.51277047, 22.35944142, 86.13169115])
y_test
array([-41.53734685, -58.05744051, -1.46024067, 40.57597798, 103.01758072, 66.10127272, 46.66025056, 56.97482206, 63.0249467 , 100.8255246 , -54.62086294, 91.25029492, 3.3520749 , -23.59621905, 1.17352349, -20.40784363, 46.35507328, 21.23129715, 5.09073378, 59.21784029, 7.90719675, 98.23545178, -1.68276177, 17.71925914, 78.40425661])
3 Evaluate with R²
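R² (the coefficient of determination) compares the squared prediction errors with the spread of the true values around their mean ȳ:

```latex
R^2 = 1 - \frac{\sum_{i} (y_i - \hat{y}_i)^2}{\sum_{i} (y_i - \bar{y})^2}
```

R² = 1 means perfect predictions, R² = 0 means no better than always predicting ȳ, and it becomes negative when the model does worse than that.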
from sklearn.metrics import r2_score
r2_score(y_test,y_pred)
0.9634297760055799
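For comparison, scikit-learn's built-in KNeighborsRegressor does the same averaging, and R² can be computed by hand as 1 - SS_res / SS_tot to confirm what r2_score reports. This is a standalone sketch: it regenerates data in the same style as above (y = 5·x1 + 4·x2 plus noise) using its own seeded generator, so its numbers will not match the outputs shown here, and n_neighbors=5 is an assumed choice.

```python
import numpy as np
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor

# Regenerate data in the same style as above: two linear features, y = 5*x1 + 4*x2, plus noise
rng = np.random.default_rng(10)
x1 = np.linspace(-10, 10, 100)
x2 = np.linspace(-5, 15, 100)
data = np.column_stack([x1, x2, 5 * x1 + 4 * x2]) + rng.normal(0.1, 1, (100, 3))
X_train, X_test, y_train, y_test = train_test_split(data[:, :2], data[:, -1], random_state=0)

reg = KNeighborsRegressor(n_neighbors=5)  # uniform weights: plain mean of the 5 neighbors
y_pred = reg.fit(X_train, y_train).predict(X_test)

# R^2 by hand: 1 - (residual sum of squares) / (total sum of squares)
ss_res = np.sum((y_test - y_pred) ** 2)
ss_tot = np.sum((y_test - y_test.mean()) ** 2)
r2_manual = 1 - ss_res / ss_tot

print(r2_manual, r2_score(y_test, y_pred))  # the two values should agree
```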