【Python Machine Learning】Experiment 07: KNN Nearest Neighbor Algorithm 2


Trying Scikit-learn

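sklearn.neighbors.KNeighborsClassifier
  • n_neighbors: the number of neighbors, i.e. the value of k; default 5
  • p: the power parameter of the Minkowski distance metric; default 2
  • algorithm: the nearest-neighbor search algorithm, one of {'auto', 'ball_tree', 'kd_tree', 'brute'}
  • weights: how the neighbors are weighted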
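  • n_neighbors : int, optional (default = 5)
    The number of neighbors used by default for kneighbors queries. This is the k in k-NN: the k closest points are selected.
  • weights : str or callable, optional (default = 'uniform')

The default is 'uniform'; the value can be 'uniform', 'distance', or a user-defined function. 'uniform' means equal weights: every neighbor counts the same. 'distance' weights each neighbor by the inverse of its distance, so closer neighbors have more influence than farther ones. A user-defined function receives an array of distances and returns an array of weights with the same shape.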

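algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, optional

The algorithm used for fast k-nearest-neighbor search. The default, 'auto', lets the estimator decide which search algorithm suits the training data. You can also choose 'ball_tree', 'kd_tree', or 'brute' yourself. 'brute' is brute-force search, i.e. a linear scan over all training points, which becomes very time-consuming when the training set is large. 'kd_tree' builds a k-d tree, a tree-shaped data structure (a kind of binary tree) that stores the data for fast retrieval.

  • A k-d tree is built by splitting at the median, so each node is a hyperrectangle; it is efficient when the number of dimensions is below about 20. The ball tree was invented to overcome the k-d tree's loss of efficiency in high dimensions: it partitions the sample space by a centroid C and a radius r, so each node is a hypersphere.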

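  • leaf_size : int, optional (default = 30)
    The leaf size passed to the k-d tree or ball tree. It affects the speed of building the tree and of queries, as well as the memory needed to store the tree. The best value depends on the nature of the problem.
  • p : int, optional (default = 2)
    The parameter of the distance formula. In the previous section we measured distance with the Euclidean formula; other measures exist, such as the Manhattan distance. The default p = 2 means the Euclidean distance is used; setting p = 1 uses the Manhattan distance.
  • metric : str or callable, default 'minkowski'
    The distance metric. The default is 'minkowski', which with the default p = 2 is the standard Euclidean distance.
  • metric_params : dict, optional (default = None)
    Additional keyword arguments for the metric function; you can usually ignore it and keep the default None.
  • n_jobs : int or None, optional (default = None)
    The number of parallel jobs for the neighbor search. None means 1 (unless inside a joblib parallel_backend context); -1 means all CPU cores are used.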
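As a rough illustration of `weights` and `algorithm`, the sketch below fits the classifier with every combination of the three kinds of weights and the three explicit search algorithms. It assumes the iris data set restricted to its first two classes as a stand-in for this experiment's data (which is prepared in an earlier part and not shown here), and `inverse_square` is a hypothetical custom weight function. The search algorithm should only change speed, not the neighbors found, so for a given `weights` setting the scores are expected to agree across `algorithm` values (exact distance ties aside).

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Stand-in data (assumption): the first two iris classes, 100 samples
X, y = load_iris(return_X_y=True)
X, y = X[y < 2], y[y < 2]
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

def inverse_square(distances):
    # hypothetical custom weight function: receives the distance array
    # (shape n_queries x n_neighbors) and returns weights of the same shape
    return 1.0 / (distances ** 2 + 1e-9)

scores = {}
for w in ["uniform", "distance", inverse_square]:
    for algo in ["brute", "kd_tree", "ball_tree"]:
        knn = KNeighborsClassifier(n_neighbors=5, weights=w, algorithm=algo)
        knn.fit(X_train, y_train)
        name = w if isinstance(w, str) else w.__name__
        scores[(name, algo)] = knn.score(X_test, y_test)

for key, s in scores.items():
    print(key, round(s, 3))
```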
# 1 Import the module
from sklearn.neighbors import KNeighborsClassifier
# 2 Create a KNN classifier instance
knn=KNeighborsClassifier(n_neighbors=4)
# 3 Fit the model
knn.fit(X_train,y_train)
# 4 Get the score (mean accuracy on the test set)
knn.score(X_test,y_test)
1.0
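A fitted KNeighborsClassifier can also report which training points it relies on through its `kneighbors` method, which returns the distances and indices of the nearest neighbors. The sketch below uses it to check the `p` parameter against a hand-written Minkowski distance, (sum over i of |a_i - b_i|^p)^(1/p). The iris subset is again only a stand-in for this experiment's data, `minkowski` is a helper written for illustration, and `algorithm="kd_tree"` is fixed so the distances are computed directly.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
X, y = X[y < 2], y[y < 2]          # stand-in data (assumption)
query = X[:1]                      # one query point, shape (1, 4)

def minkowski(a, b, p):
    # (sum_i |a_i - b_i|^p)^(1/p): p=1 is Manhattan, p=2 is Euclidean
    return np.sum(np.abs(a - b) ** p) ** (1.0 / p)

results = {}
for p in (1, 2):
    knn = KNeighborsClassifier(n_neighbors=4, p=p, algorithm="kd_tree").fit(X, y)
    dist, ind = knn.kneighbors(query)      # both have shape (1, 4), sorted by distance
    manual = np.array([minkowski(query[0], X[i], p) for i in ind[0]])
    results[p] = (dist[0], manual, ind[0])
    print(f"p={p}: neighbors {ind[0]}, distances {np.round(dist[0], 3)}")
```

Because the query point is itself part of the training data, its first neighbor is at distance 0.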

Try other numbers of neighbors

# 1 Import the module
from sklearn.neighbors import KNeighborsClassifier
# 2 Create a KNN classifier instance
knn=KNeighborsClassifier(n_neighbors=2)
# 3 Fit the model
knn.fit(X_train,y_train)
# 4 Get the score (mean accuracy on the test set)
knn.score(X_test,y_test)
1.0
# 1 Import the module
from sklearn.neighbors import KNeighborsClassifier
# 2 Create a KNN classifier instance
knn=KNeighborsClassifier(n_neighbors=6)
# 3 Fit the model
knn.fit(X_train,y_train)
# 4 Get the score (mean accuracy on the test set)
knn.score(X_test,y_test)
1.0
#5 Search for the best number of neighbors K; here K ranges from 1 to 10
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
def getBestK(X_train,y_train,K):
    best_score=0
    best_k=1
    best_model=KNeighborsClassifier(1)
    # hold out a validation set from the training data (one fixed split)
    X_train_set,X_val,y_train_set,y_val=train_test_split(X_train,y_train,random_state=0)
    for num in range(1,K):
        knn=KNeighborsClassifier(num)
        knn.fit(X_train_set,y_train_set)
        score=round(knn.score(X_val,y_val),2)
        print(score,num)
        # keep the first k that reaches the highest validation score
        if score>best_score:
            best_k=num
            best_score=score
            best_model=knn
    return best_k,best_score,best_model
best_k,best_score,best_model=getBestK(X_train,y_train,11)
0.95 1
0.95 2
0.95 3
0.95 4
0.95 5
1.0 6
1.0 7
1.0 8
1.0 9
1.0 10
#5 Evaluate the selected model on the test set (an estimate of its generalization performance)
best_model.score(X_test,y_test)
1.0

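The k above was chosen on the validation set of a single split of the training data, so the choice is partly a matter of chance: the k with the highest validation score on one split is not guaranteed to perform best on the test set (here it happens to reach 1.0).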
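To see how much a single split can sway the choice, the sketch below repeats the single-split selection with ten different `random_state` values and records the winning k each time. It assumes the full three-class iris data set as a stand-in (the two-class subset is so easy that every k ties), and `best_k_one_split` is a hypothetical helper that mirrors `getBestK`. The chosen k may well change from seed to seed.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)   # stand-in data (assumption): all 3 classes

def best_k_one_split(X, y, seed, k_max=10):
    # pick k on one train/validation split, like getBestK above
    X_tr, X_val, y_tr, y_val = train_test_split(X, y, random_state=seed)
    scores = [KNeighborsClassifier(n_neighbors=k).fit(X_tr, y_tr).score(X_val, y_val)
              for k in range(1, k_max + 1)]
    # max() returns the first index with the top score, i.e. the smallest such k
    return 1 + max(range(k_max), key=lambda i: scores[i])

chosen = [best_k_one_split(X, y, seed) for seed in range(10)]
print(chosen)
```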

#6 Try the cross-validation error
from sklearn.model_selection import RepeatedKFold
rkf=RepeatedKFold(n_repeats=10,n_splits=5,random_state=42)
for i,(train_index,test_index) in enumerate(rkf.split(X_train)):
    print("train_index",train_index)
    print("test_index",test_index)
#     print("new training data:",X_train[train_index],y_train[train_index])
#     print("new validation data:",X_train[test_index],y_train[test_index])
train_index [ 1  2  3  5  6  7  8 11 13 14 15 16 17 19 20 21 22 23 24 25 26 27 29 30
 31 32 33 36 37 38 39 40 41 43 44 45 46 47 48 50 51 52 53 54 55 56 57 58
 59 60 62 65 66 67 68 70 71 72 73 74]
test_index [ 0  4  9 10 12 18 28 34 35 42 49 61 63 64 69]
train_index [ 0  1  2  3  4  6  8  9 10 11 12 13 14 15 17 18 19 20 21 23 24 25 26 27
 28 29 32 34 35 36 37 38 41 42 43 46 48 49 50 51 52 53 54 55 57 59 60 61
 62 63 64 65 67 68 69 70 71 72 73 74]
test_index [ 5  7 16 22 30 31 33 39 40 44 45 47 56 58 66]
train_index [ 0  1  2  4  5  7  9 10 11 12 14 15 16 18 20 21 22 23 24 26 27 28 29 30
 31 32 33 34 35 37 39 40 41 42 43 44 45 46 47 48 49 51 52 55 56 57 58 59
 60 61 63 64 65 66 67 68 69 70 71 73]
test_index [ 3  6  8 13 17 19 25 36 38 50 53 54 62 72 74]
train_index [ 0  1  2  3  4  5  6  7  8  9 10 12 13 14 16 17 18 19 20 21 22 23 25 28
 29 30 31 33 34 35 36 37 38 39 40 42 44 45 47 49 50 51 52 53 54 56 58 59
 60 61 62 63 64 65 66 69 70 71 72 74]
test_index [11 15 24 26 27 32 41 43 46 48 55 57 67 68 73]
train_index [ 0  3  4  5  6  7  8  9 10 11 12 13 15 16 17 18 19 22 24 25 26 27 28 30
 31 32 33 34 35 36 38 39 40 41 42 43 44 45 46 47 48 49 50 53 54 55 56 57
 58 61 62 63 64 66 67 68 69 72 73 74]
test_index [ 1  2 14 20 21 23 29 37 51 52 59 60 65 70 71]
train_index [ 0  2  3  4  6  7  8  9 10 11 12 13 14 16 18 19 21 22 23 24 25 26 27 28
 30 32 33 34 35 36 37 38 39 40 41 42 43 44 47 48 50 52 53 54 55 56 57 58
 59 61 62 64 65 66 67 68 70 71 72 73]
test_index [ 1  5 15 17 20 29 31 45 46 49 51 60 63 69 74]
train_index [ 0  1  2  4  5  6  7  8 10 11 13 14 15 16 17 20 21 22 23 25 26 27 28 29
 31 32 33 34 35 36 38 39 40 41 43 44 45 46 49 50 51 52 53 54 55 56 57 59
 60 61 62 63 64 65 66 69 70 71 73 74]
test_index [ 3  9 12 18 19 24 30 37 42 47 48 58 67 68 72]
train_index [ 0  1  3  4  5  6  7  8  9 10 11 12 14 15 16 17 18 19 20 23 24 25 27 28
 29 30 31 32 34 37 38 40 41 42 43 44 45 46 47 48 49 50 51 52 56 57 58 59
 60 62 63 64 65 67 68 69 70 72 73 74]
test_index [ 2 13 21 22 26 33 35 36 39 53 54 55 61 66 71]
train_index [ 0  1  2  3  5  7  8  9 10 12 13 14 15 17 18 19 20 21 22 23 24 25 26 28
 29 30 31 33 35 36 37 39 40 42 43 44 45 46 47 48 49 51 52 53 54 55 58 59
 60 61 63 64 66 67 68 69 71 72 73 74]
test_index [ 4  6 11 16 27 32 34 38 41 50 56 57 62 65 70]
train_index [ 1  2  3  4  5  6  9 11 12 13 15 16 17 18 19 20 21 22 24 26 27 29 30 31
 32 33 34 35 36 37 38 39 41 42 45 46 47 48 49 50 51 53 54 55 56 57 58 60
 61 62 63 65 66 67 68 69 70 71 72 74]
test_index [ 0  7  8 10 14 23 25 28 40 43 44 52 59 64 73]
train_index [ 0  1  2  3  4  5  7  8 10 11 14 16 18 19 20 21 22 23 24 25 26 27 28 29
 31 32 35 36 38 39 40 41 42 43 45 46 47 48 49 50 51 52 53 54 55 56 57 58
 61 62 63 64 66 67 68 69 71 72 73 74]
test_index [ 6  9 12 13 15 17 30 33 34 37 44 59 60 65 70]
train_index [ 0  1  2  5  6  7  8  9 11 12 13 14 15 16 17 18 20 22 23 26 27 29 30 31
 32 33 34 36 37 38 40 41 43 44 45 47 48 50 51 53 54 55 56 57 58 59 60 61
 63 64 65 66 67 68 69 70 71 72 73 74]
test_index [ 3  4 10 19 21 24 25 28 35 39 42 46 49 52 62]
train_index [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 17 19 21 23 24 25 26 27
 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 46 49 50 51 52 53 59
 60 61 62 63 65 66 68 69 70 71 73 74]
test_index [16 18 20 22 45 47 48 54 55 56 57 58 64 67 72]
train_index [ 0  2  3  4  5  6  7  9 10 12 13 15 16 17 18 19 20 21 22 24 25 26 27 28
 29 30 33 34 35 37 38 39 42 43 44 45 46 47 48 49 52 54 55 56 57 58 59 60
 61 62 64 65 66 67 68 69 70 72 73 74]
test_index [ 1  8 11 14 23 31 32 36 40 41 50 51 53 63 71]
train_index [ 1  3  4  6  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 28 30
 31 32 33 34 35 36 37 39 40 41 42 44 45 46 47 48 49 50 51 52 53 54 55 56
 57 58 59 60 62 63 64 65 67 70 71 72]
test_index [ 0  2  5  7 26 27 29 38 43 61 66 68 69 73 74]
train_index [ 0  1  2  3  4  6  7  8 10 11 13 15 17 18 19 20 21 22 23 24 25 27 28 29
 30 31 32 33 34 36 37 38 39 40 41 44 45 46 47 48 49 51 52 53 54 55 56 57
 59 60 61 66 67 68 69 70 71 72 73 74]
test_index [ 5  9 12 14 16 26 35 42 43 50 58 62 63 64 65]
train_index [ 0  1  2  4  5  6  7  8  9 10 11 12 14 15 16 18 19 22 23 24 25 26 29 30
 31 32 34 35 36 37 38 39 40 41 42 43 44 47 48 49 50 51 55 56 57 58 59 62
 63 64 65 66 67 68 69 70 71 72 73 74]
test_index [ 3 13 17 20 21 27 28 33 45 46 52 53 54 60 61]
train_index [ 0  1  3  4  5  6  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 25 26 27
 28 29 30 31 32 33 34 35 36 38 39 41 42 43 45 46 47 48 49 50 51 52 53 54
 55 56 58 60 61 62 63 64 65 67 70 71]
test_index [ 2  7 23 24 37 40 44 57 59 66 68 69 72 73 74]
train_index [ 0  2  3  5  7  9 10 12 13 14 16 17 18 19 20 21 22 23 24 26 27 28 29 30
 32 33 35 37 38 39 40 41 42 43 44 45 46 49 50 51 52 53 54 56 57 58 59 60
 61 62 63 64 65 66 68 69 70 72 73 74]
test_index [ 1  4  6  8 11 15 25 31 34 36 47 48 55 67 71]
train_index [ 1  2  3  4  5  6  7  8  9 11 12 13 14 15 16 17 20 21 23 24 25 26 27 28
 31 33 34 35 36 37 40 42 43 44 45 46 47 48 50 52 53 54 55 57 58 59 60 61
 62 63 64 65 66 67 68 69 71 72 73 74]
test_index [ 0 10 18 19 22 29 30 32 38 39 41 49 51 56 70]
train_index [ 0  1  2  3  4  5  7  8  9 13 14 16 17 18 20 21 22 23 24 25 26 27 28 29
 30 31 32 34 35 36 37 38 40 41 42 43 44 45 46 47 48 50 53 54 56 59 60 61
 63 64 65 66 67 68 69 70 71 72 73 74]
test_index [ 6 10 11 12 15 19 33 39 49 51 52 55 57 58 62]
train_index [ 2  3  4  5  6  7 10 11 12 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
 30 31 32 33 34 36 37 39 40 42 43 45 46 47 48 49 50 51 52 53 55 56 57 58
 59 60 61 62 63 64 65 66 67 69 72 74]
test_index [ 0  1  8  9 13 14 35 38 41 44 54 68 70 71 73]
train_index [ 0  1  3  4  5  6  7  8  9 10 11 12 13 14 15 16 18 19 20 26 27 28 29 32
 33 34 35 36 37 38 39 40 41 43 44 45 47 48 49 50 51 52 53 54 55 56 57 58
 59 60 62 63 65 66 68 69 70 71 73 74]
test_index [ 2 17 21 22 23 24 25 30 31 42 46 61 64 67 72]
train_index [ 0  1  2  4  6  7  8  9 10 11 12 13 14 15 17 19 20 21 22 23 24 25 26 27
 29 30 31 32 33 35 37 38 39 41 42 44 46 49 50 51 52 53 54 55 57 58 59 60
 61 62 63 64 67 68 69 70 71 72 73 74]
test_index [ 3  5 16 18 28 34 36 40 43 45 47 48 56 65 66]
train_index [ 0  1  2  3  5  6  8  9 10 11 12 13 14 15 16 17 18 19 21 22 23 24 25 28
 30 31 33 34 35 36 38 39 40 41 42 43 44 45 46 47 48 49 51 52 54 55 56 57
 58 61 62 64 65 66 67 68 70 71 72 73]
test_index [ 4  7 20 26 27 29 32 37 50 53 59 60 63 69 74]
train_index [ 0  1  3  4  5  7  8 11 12 13 14 15 16 18 19 20 21 22 23 24 25 26 27 28
 29 30 31 32 34 35 36 37 38 39 41 42 43 44 45 46 48 50 51 52 54 56 57 58
 59 60 62 63 64 65 66 67 69 70 73 74]
test_index [ 2  6  9 10 17 33 40 47 49 53 55 61 68 71 72]
train_index [ 2  3  4  5  6  7  9 10 12 13 14 15 16 17 18 19 21 24 25 27 29 31 32 33
 34 35 36 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 55 57 58 59 60
 61 62 63 64 65 66 67 68 69 70 71 72]
test_index [ 0  1  8 11 20 22 23 26 28 30 37 54 56 73 74]
train_index [ 0  1  2  5  6  7  8  9 10 11 13 14 15 17 19 20 21 22 23 24 26 28 30 31
 32 33 35 36 37 40 41 42 43 44 46 47 48 49 50 51 53 54 55 56 57 58 59 60
 61 62 63 64 65 67 68 70 71 72 73 74]
test_index [ 3  4 12 16 18 25 27 29 34 38 39 45 52 66 69]
train_index [ 0  1  2  3  4  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 25 26 27 28 29 30 31 33 34 35 37 38 39 40 44 45 47 49 50 52 53 54 55 56
 57 61 62 64 65 66 68 69 71 72 73 74]
test_index [ 5 32 36 41 42 43 46 48 51 58 59 60 63 67 70]
train_index [ 0  1  2  3  4  5  6  8  9 10 11 12 16 17 18 20 22 23 25 26 27 28 29 30
 32 33 34 36 37 38 39 40 41 42 43 45 46 47 48 49 51 52 53 54 55 56 58 59
 60 61 63 66 67 68 69 70 71 72 73 74]
test_index [ 7 13 14 15 19 21 24 31 35 44 50 57 62 64 65]
train_index [ 0  1  2  3  4  6  7  8  9 10 11 12 13 15 16 17 18 19 22 23 24 26 27 28
 30 31 32 33 34 35 36 37 38 39 43 44 45 46 47 48 51 52 53 54 55 56 57 59
 60 61 62 65 66 67 68 69 70 72 73 74]
test_index [ 5 14 20 21 25 29 40 41 42 49 50 58 63 64 71]
train_index [ 0  1  2  3  4  5  7  9 11 14 15 18 19 20 21 22 23 25 26 27 28 29 30 31
 32 33 34 35 36 37 38 39 40 41 42 44 46 47 48 49 50 51 52 53 55 56 57 58
 60 61 62 63 64 65 67 68 69 70 71 72]
test_index [ 6  8 10 12 13 16 17 24 43 45 54 59 66 73 74]
train_index [ 0  1  3  4  5  6  8  9 10 12 13 14 15 16 17 18 20 21 22 23 24 25 28 29
 30 31 32 33 35 38 40 41 42 43 44 45 46 47 48 49 50 51 53 54 56 57 58 59
 60 61 62 63 64 66 68 69 71 72 73 74]
test_index [ 2  7 11 19 26 27 34 36 37 39 52 55 65 67 70]
train_index [ 2  4  5  6  7  8  9 10 11 12 13 14 15 16 17 19 20 21 22 24 25 26 27 28
 29 32 34 36 37 38 39 40 41 42 43 45 46 47 49 50 52 53 54 55 56 57 58 59
 61 63 64 65 66 67 68 70 71 72 73 74]
test_index [ 0  1  3 18 23 30 31 33 35 44 48 51 60 62 69]
train_index [ 0  1  2  3  5  6  7  8 10 11 12 13 14 16 17 18 19 20 21 23 24 25 26 27
 29 30 31 33 34 35 36 37 39 40 41 42 43 44 45 48 49 50 51 52 54 55 58 59
 60 62 63 64 65 66 67 69 70 71 73 74]
test_index [ 4  9 15 22 28 32 38 46 47 53 56 57 61 68 72]
train_index [ 2  3  4  6  8  9 10 11 12 13 14 15 16 18 19 20 21 22 23 24 26 27 29 30
 32 33 34 35 36 37 38 39 40 42 44 45 46 47 48 49 50 51 53 54 56 59 60 61
 62 63 64 65 66 67 68 70 71 72 73 74]
test_index [ 0  1  5  7 17 25 28 31 41 43 52 55 57 58 69]
train_index [ 0  1  3  4  5  6  7  8 11 12 13 15 16 17 18 19 20 21 22 23 24 25 27 28
 29 30 31 32 34 35 36 40 41 43 44 45 47 48 50 52 53 54 55 56 57 58 59 60
 61 63 64 65 67 68 69 70 71 72 73 74]
test_index [ 2  9 10 14 26 33 37 38 39 42 46 49 51 62 66]
train_index [ 0  1  2  5  7  9 10 11 12 14 16 17 18 19 21 22 23 24 25 26 28 29 31 33
 34 35 36 37 38 39 40 41 42 43 46 47 48 49 50 51 52 54 55 56 57 58 59 61
 62 63 65 66 67 68 69 70 71 72 73 74]
test_index [ 3  4  6  8 13 15 20 27 30 32 44 45 53 60 64]
train_index [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 17 20 22 23 24 25 26 27
 28 30 31 32 33 34 35 36 37 38 39 41 42 43 44 45 46 48 49 51 52 53 54 55
 57 58 60 61 62 63 64 66 68 69 72 73]
test_index [16 18 19 21 29 40 47 50 56 59 65 67 70 71 74]
train_index [ 0  1  2  3  4  5  6  7  8  9 10 13 14 15 16 17 18 19 20 21 25 26 27 28
 29 30 31 32 33 37 38 39 40 41 42 43 44 45 46 47 49 50 51 52 53 55 56 57
 58 59 60 62 64 65 66 67 69 70 71 74]
test_index [11 12 22 23 24 34 35 36 48 54 61 63 68 72 73]
train_index [ 0  2  3  4  5  7  8  9 10 12 13 14 15 16 17 18 19 20 22 24 25 26 27 28
 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 46 47 48 49 51 52 53 57 58
 59 60 61 62 63 64 65 66 67 69 73 74]
test_index [ 1  6 11 21 23 29 45 50 54 55 56 68 70 71 72]
train_index [ 0  1  2  3  4  5  6  7  9 10 11 12 15 16 18 19 20 21 23 24 25 26 27 28
 29 30 31 32 34 35 36 37 38 39 40 43 44 45 46 48 49 50 51 52 53 54 55 56
 57 59 60 63 64 65 66 68 69 70 71 72]
test_index [ 8 13 14 17 22 33 41 42 47 58 61 62 67 73 74]
train_index [ 1  2  3  4  5  6  7  8  9 11 12 13 14 16 17 18 19 21 22 23 25 26 27 28
 29 30 33 35 36 37 38 41 42 43 44 45 47 48 50 53 54 55 56 57 58 59 60 61
 62 64 65 66 67 68 69 70 71 72 73 74]
test_index [ 0 10 15 20 24 31 32 34 39 40 46 49 51 52 63]
train_index [ 0  1  3  4  5  6  7  8 10 11 13 14 15 16 17 18 20 21 22 23 24 28 29 30
 31 32 33 34 35 36 37 39 40 41 42 44 45 46 47 49 50 51 52 54 55 56 58 59
 61 62 63 64 65 67 68 70 71 72 73 74]
test_index [ 2  9 12 19 25 26 27 38 43 48 53 57 60 66 69]
train_index [ 0  1  2  6  8  9 10 11 12 13 14 15 17 19 20 21 22 23 24 25 26 27 29 31
 32 33 34 38 39 40 41 42 43 45 46 47 48 49 50 51 52 53 54 55 56 57 58 60
 61 62 63 66 67 68 69 70 71 72 73 74]
test_index [ 3  4  5  7 16 18 28 30 35 36 37 44 59 64 65]
train_index [ 0  1  2  4  5  9 10 12 15 16 17 18 19 20 21 22 24 25 26 27 28 29 30 31
 32 33 34 36 38 39 40 41 42 44 45 46 47 48 49 50 51 52 54 55 56 57 58 59
 60 61 62 63 64 65 66 68 69 71 72 73]
test_index [ 3  6  7  8 11 13 14 23 35 37 43 53 67 70 74]
train_index [ 0  1  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 20 21 22 23 24 25 26
 27 28 29 31 32 33 34 35 37 40 42 43 44 45 46 47 49 50 53 54 55 56 57 58
 59 60 61 62 63 65 67 68 69 70 72 74]
test_index [ 2 18 19 30 36 38 39 41 48 51 52 64 66 71 73]
train_index [ 0  1  2  3  4  5  6  7  8  9 11 12 13 14 16 17 18 19 23 24 26 27 28 29
 30 32 34 35 36 37 38 39 40 41 43 44 45 46 48 49 50 51 52 53 56 57 58 59
 60 62 63 64 65 66 67 70 71 72 73 74]
test_index [10 15 20 21 22 25 31 33 42 47 54 55 61 68 69]
train_index [ 2  3  6  7  8 10 11 12 13 14 15 16 18 19 20 21 22 23 25 26 27 28 30 31
 32 33 34 35 36 37 38 39 40 41 42 43 45 47 48 49 51 52 53 54 55 57 59 60
 61 62 63 64 66 67 68 69 70 71 73 74]
test_index [ 0  1  4  5  9 17 24 29 44 46 50 56 58 65 72]
train_index [ 0  1  2  3  4  5  6  7  8  9 10 11 13 14 15 17 18 19 20 21 22 23 24 25
 29 30 31 33 35 36 37 38 39 41 42 43 44 46 47 48 50 51 52 53 54 55 56 58
 61 64 65 66 67 68 69 70 71 72 73 74]
test_index [12 16 26 27 28 32 34 40 45 49 57 59 60 62 63]
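The printout above has 50 train/test pairs: 5 folds times 10 repeats. The sketch below checks two properties of RepeatedKFold on 75 dummy rows (the same size as this X_train, an assumption based on the indices shown): there are `n_splits * n_repeats` splits, and within each repeat the 5 test folds cover every sample exactly once.

```python
import numpy as np
from sklearn.model_selection import RepeatedKFold

n_samples = 75                          # same size as this X_train
X_dummy = np.zeros((n_samples, 1))      # only the number of rows matters for split()
rkf = RepeatedKFold(n_splits=5, n_repeats=10, random_state=42)

splits = list(rkf.split(X_dummy))
print(len(splits), rkf.get_n_splits())  # 5 folds x 10 repeats

# within each repeat, the 5 test folds should partition all samples exactly once
partition_ok = []
for r in range(10):
    test_parts = [test for _, test in splits[5 * r: 5 * (r + 1)]]
    covered = np.sort(np.concatenate(test_parts))
    partition_ok.append(np.array_equal(covered, np.arange(n_samples)))
print(all(partition_ok))
```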
from sklearn.model_selection import cross_validate
cross_validate(knn,X_train,y_train,cv=rkf,scoring="accuracy",return_estimator=True)
{'fit_time': array([0.00099969, 0.        , 0.00099897, 0.        , 0.        ,
        0.00100088, 0.00100112, 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.00099134, 0.00101256, 0.00099635,
        0.        , 0.        , 0.        , 0.00099874, 0.        ,
        0.00105643, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.00100422,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ]),
 'score_time': array([0.00099945, 0.00100017, 0.        , 0.00099826, 0.0010016 ,
        0.00099826, 0.00112462, 0.00212598, 0.00103188, 0.00099683,
        0.0009737 , 0.00103641, 0.        , 0.        , 0.        ,
        0.00097394, 0.00102925, 0.00099778, 0.        , 0.00100136,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.00100565, 0.00099897, 0.        , 0.00099373, 0.00099897,
        0.00100088, 0.00106072, 0.00103712, 0.00107408, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.00101113, 0.0010767 , 0.00099373, 0.00093102]),
 'estimator': [KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6),
  KNeighborsClassifier(n_neighbors=6)],
 'test_score': array([1.        , 1.        , 1.        , 1.        , 0.93333333,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        0.93333333, 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 1.        ])}
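The dictionary above holds one entry per fold; usually only the mean and spread of `test_score` are of interest. A minimal sketch of that summary, assuming the two-class iris subset as a stand-in for this experiment's data:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import RepeatedKFold, cross_validate
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
X, y = X[y < 2], y[y < 2]                     # stand-in data (assumption)
rkf = RepeatedKFold(n_splits=5, n_repeats=10, random_state=42)
result = cross_validate(KNeighborsClassifier(n_neighbors=6), X, y,
                        cv=rkf, scoring="accuracy")

scores = result["test_score"]                 # one accuracy per fold: 50 values
print(f"accuracy: {scores.mean():.3f} +/- {scores.std():.3f} over {len(scores)} folds")
```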
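#5 Search for the best K again, this time scoring each candidate with repeated 5-fold cross-validation (K from 1 to 14)
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_validate
from sklearn.model_selection import RepeatedKFold
def getBestK(X_train,y_train,K):
    best_score=0
    best_k=1
#     X_train_set,X_val,y_train_set,y_val=train_test_split(X_train,y_train)
    # 5 folds x 5 repeats = 25 train/validation splits per candidate k
    rkf=RepeatedKFold(n_repeats=5,n_splits=5,random_state=42)
    for num in range(1,K):
        knn=KNeighborsClassifier(num)
        # F1 score on each fold (the default "f1" scorer expects binary labels)
        result=cross_validate(knn,X_train,y_train,cv=rkf,scoring="f1")
        score=result["test_score"].mean()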
        score=round(score,2)
        print(score,num)
        if score>best_score:
            best_k=num
            best_score=score
    return best_k,best_score
best_k,best_score=getBestK(X_train,y_train,15)
best_k,best_score
0.98 1
0.99 2
0.99 3
0.99 4
0.99 5
0.99 6
1.0 7
0.99 8
0.99 9
0.98 10
0.98 11
0.97 12
0.98 13
0.97 14
(7, 1.0)
knn=KNeighborsClassifier(best_k)
knn.fit(X_train,y_train)
knn.score(X_test,y_test)
1.0

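In short, tune the hyperparameter automatically: loop over candidate values and keep the k with the best cross-validated score.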
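Scikit-learn packages this loop as GridSearchCV, which evaluates every parameter combination with the given cross-validation splitter and then refits the best one on the whole training set. The sketch below is a swapped-in alternative to the hand-written `getBestK`, not the author's method; the two-class iris subset is a stand-in for this experiment's data, and searching over `weights` as well is an extra assumption.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, RepeatedKFold, train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
X, y = X[y < 2], y[y < 2]                      # stand-in data (assumption)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

search = GridSearchCV(
    KNeighborsClassifier(),
    param_grid={"n_neighbors": list(range(1, 15)),
                "weights": ["uniform", "distance"]},
    cv=RepeatedKFold(n_splits=5, n_repeats=5, random_state=42),
    scoring="accuracy",
)
search.fit(X_train, y_train)          # refits the best model on all of X_train
print(search.best_params_, round(search.best_score_, 3))
print(search.score(X_test, y_test))   # test accuracy of the refitted best model
```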

Experiment: using KNN for a regression task

1 Prepare the data

import numpy as np
x1=np.linspace(-10,10,100)
x2=np.linspace(-5,15,100)
# construct some data by hand: y = w1*x1 + w2*x2
w1=5
w2=4
y=x1*w1+x2*w2
y
array([-70.        , -68.18181818, -66.36363636, -64.54545455,
       -62.72727273, -60.90909091, -59.09090909, -57.27272727,
       -55.45454545, -53.63636364, -51.81818182, -50.        ,
       -48.18181818, -46.36363636, -44.54545455, -42.72727273,
       -40.90909091, -39.09090909, -37.27272727, -35.45454545,
       -33.63636364, -31.81818182, -30.        , -28.18181818,
       -26.36363636, -24.54545455, -22.72727273, -20.90909091,
       -19.09090909, -17.27272727, -15.45454545, -13.63636364,
       -11.81818182, -10.        ,  -8.18181818,  -6.36363636,
        -4.54545455,  -2.72727273,  -0.90909091,   0.90909091,
         2.72727273,   4.54545455,   6.36363636,   8.18181818,
        10.        ,  11.81818182,  13.63636364,  15.45454545,
        17.27272727,  19.09090909,  20.90909091,  22.72727273,
        24.54545455,  26.36363636,  28.18181818,  30.        ,
        31.81818182,  33.63636364,  35.45454545,  37.27272727,
        39.09090909,  40.90909091,  42.72727273,  44.54545455,
        46.36363636,  48.18181818,  50.        ,  51.81818182,
        53.63636364,  55.45454545,  57.27272727,  59.09090909,
        60.90909091,  62.72727273,  64.54545455,  66.36363636,
        68.18181818,  70.        ,  71.81818182,  73.63636364,
        75.45454545,  77.27272727,  79.09090909,  80.90909091,
        82.72727273,  84.54545455,  86.36363636,  88.18181818,
        90.        ,  91.81818182,  93.63636364,  95.45454545,
        97.27272727,  99.09090909, 100.90909091, 102.72727273,
       104.54545455, 106.36363636, 108.18181818, 110.        ])
x1=x1.reshape(len(x1),1)
x2=x2.reshape(len(x2),1)
y=y.reshape(len(y),1)
import pandas as pd
data=np.hstack([x1,x2,y])
# add some noise to the data (seed the generator so the noise is reproducible)
np.random.seed(10)
data=data+np.random.normal(0.1,1,[100,3])
data
array([[-9.80997918e+00, -4.47671228e+00, -6.86113562e+01],
       [-9.07863100e+00, -3.29030887e+00, -6.75412089e+01],
       [-8.17535392e+00, -4.85515660e+00, -6.56682184e+01],
       [-9.33603110e+00, -4.67304042e+00, -6.39943055e+01],
       [-8.31454149e+00, -3.61401814e+00, -6.15552168e+01],
       [-9.35462761e+00, -3.99216837e+00, -6.16450829e+01],
       [-7.35641032e+00, -5.10713257e+00, -5.80574405e+01],
       [-7.75808720e+00, -2.81374154e+00, -5.72785817e+01],
       [-7.85420726e+00, -3.25192460e+00, -5.58260703e+01],
       [-7.79785201e+00, -4.59268755e+00, -5.46208629e+01],
       [-9.90411101e+00, -7.55985286e-01, -5.19239440e+01],
       [-4.91167456e+00, -1.48242138e+00, -5.06778041e+01],
       [-9.25608953e+00, -1.12391146e+00, -4.80701720e+01],
       [-6.92987717e+00, -3.58106474e+00, -4.58459514e+01],
       [-7.19890084e+00, -2.10260074e+00, -4.46497119e+01],
       [-8.56812108e+00, -2.45314063e+00, -4.19130070e+01],
       [-6.97527315e+00, -3.25615055e+00, -4.15373469e+01],
       [-6.09201512e+00, -1.07060626e+00, -4.05034362e+01],
       [-5.94248008e+00,  6.42232477e-01, -3.64281226e+01],
       [-5.99567467e+00, -2.26531046e+00, -3.32873129e+01],
       [-7.56906953e+00, -6.81005515e-01, -3.42368449e+01],
       [-6.54272630e+00, -7.32829423e-01, -3.18556358e+01],
       [-4.68241322e+00, -1.55653397e+00, -2.99105801e+01],
       [-5.61148642e+00, -1.96269845e+00, -2.80144819e+01],
       [-4.64818297e+00,  2.21684956e-01, -2.56420739e+01],
       [-5.64237828e+00, -5.05215614e-02, -2.44150985e+01],
       [-4.77269716e+00,  3.12543954e-01, -2.35962190e+01],
       [-3.93579614e+00,  3.14368041e-01, -2.04078436e+01],
       [-4.67599369e+00,  1.38646098e+00, -1.95569688e+01],
       [-4.56613680e+00,  2.18761537e-01, -1.76443732e+01],
       [-4.12462083e+00,  7.81731566e-01, -1.55500903e+01],
       [-5.00893448e+00,  8.43167883e-01, -1.37904298e+01],
       [-3.32575389e+00,  8.87284515e-01, -1.16870554e+01],
       [-4.60962500e+00,  2.47674165e+00, -9.43497025e+00],
       [-2.55399230e+00,  1.60304976e+00, -7.30116575e+00],
       [-3.92552974e+00,  2.02861216e+00, -8.47211685e+00],
       [-2.85445054e+00,  1.32252697e+00, -2.27221086e+00],
       [-3.20383909e+00,  1.56885433e+00, -1.46024067e+00],
       [-1.87732669e+00,  1.18972183e+00, -1.68276177e+00],
       [-1.35842429e+00,  3.76086938e+00,  3.35135047e-01],
       [-7.24957523e-01,  4.37716480e+00,  1.17352349e+00],
       [-3.70453016e+00,  5.08438460e+00,  3.35207490e+00],
       [-7.97872551e-01,  2.78241431e+00,  5.09073378e+00],
       [-3.08232423e+00,  4.21925884e+00,  7.90719675e+00],
       [ 5.28844300e-01,  4.16412164e+00,  1.01885052e+01],
       [-2.64895900e-02,  4.04451188e+00,  1.32964325e+01],
       [ 7.67644414e-01,  4.38295411e+00,  1.20330676e+01],
       [-3.17298624e-01,  5.52193479e+00,  1.44587349e+01],
       [-4.05576007e-01,  6.15916945e+00,  1.77192591e+01],
       [ 2.58635850e-01,  4.36652636e+00,  2.08469868e+01],
       [-1.15875757e+00,  5.86049204e+00,  2.12312972e+01],
       [-7.16862753e-01,  7.60609045e+00,  2.24464377e+01],
       [ 1.00827677e+00,  7.13593566e+00,  2.60236434e+01],
       [ 8.64304920e-01,  7.70071685e+00,  2.67335947e+01],
       [ 3.14401551e+00,  5.74841619e+00,  2.76627520e+01],
       [-1.18085370e-02,  5.45967297e+00,  3.01731518e+01],
       [ 9.67211352e-01,  6.30044676e+00,  3.31847137e+01],
       [ 1.32254229e+00,  6.51216091e+00,  3.31636096e+01],
       [ 9.66206984e-01,  8.15352634e+00,  3.54552668e+01],
       [ 1.50374715e+00,  8.38063421e+00,  3.82675089e+01],
       [ 1.20333031e+00,  8.30155252e+00,  4.05759780e+01],
       [ 2.84702572e+00,  7.44997601e+00,  4.16313092e+01],
       [ 2.82319554e+00,  7.03396275e+00,  4.33733979e+01],
       [ 3.88755763e+00,  9.63373825e+00,  4.63550733e+01],
       [ 3.31979805e+00,  1.00825563e+01,  4.66602506e+01],
       [ 3.67714879e+00,  8.98817386e+00,  4.71815191e+01],
       [ 5.61673924e+00,  8.83321195e+00,  4.90218726e+01],
       [ 4.64376606e+00,  1.05003123e+01,  5.16821640e+01],
       [ 3.38312917e+00,  9.93985678e+00,  5.44523927e+01],
       [ 2.90435391e+00,  8.76211593e+00,  5.72974806e+01],
       [ 1.94362594e+00,  8.37086325e+00,  5.69748221e+01],
       [ 4.86357671e+00,  8.79920772e+00,  5.92178403e+01],
       [ 5.21731274e+00,  8.76064972e+00,  6.30249467e+01],
       [ 5.86040809e+00,  1.12868041e+01,  6.26973140e+01],
       [ 4.05985223e+00,  8.65847315e+00,  6.61012727e+01],
       [ 6.19899121e+00,  8.30649111e+00,  6.37680817e+01],
       [ 5.73989925e+00,  1.00161474e+01,  6.92336558e+01],
       [ 5.38266361e+00,  1.03971821e+01,  7.17084241e+01],
       [ 7.23264561e+00,  1.20494918e+01,  7.05362027e+01],
       [ 6.11948179e+00,  1.19855375e+01,  7.55318286e+01],
       [ 8.03847795e+00,  9.79749582e+00,  7.47950707e+01],
       [ 8.30070319e+00,  1.07233637e+01,  7.93806649e+01],
       [ 7.44456666e+00,  1.11936713e+01,  7.84042566e+01],
       [ 6.87035796e+00,  1.23168763e+01,  8.01532295e+01],
       [ 6.57153443e+00,  1.12686434e+01,  8.32735790e+01],
       [ 8.06216701e+00,  1.26805930e+01,  8.58973008e+01],
       [ 8.75001919e+00,  1.36698902e+01,  8.72099703e+01],
       [ 7.30252179e+00,  1.34260600e+01,  8.71816534e+01],
       [ 1.02174549e+01,  1.12734356e+01,  9.06574864e+01],
       [ 9.16397441e+00,  1.35946035e+01,  9.12502949e+01],
       [ 7.65119402e+00,  1.26062408e+01,  9.37067133e+01],
       [ 7.88012441e+00,  1.20190767e+01,  9.49682650e+01],
       [ 8.32044954e+00,  1.32807945e+01,  9.65808990e+01],
       [ 8.01089317e+00,  1.64722621e+01,  9.82354518e+01],
       [ 9.02271142e+00,  1.33190747e+01,  1.00825525e+02],
       [ 8.09970303e+00,  1.46680917e+01,  1.03017581e+02],
       [ 1.13875348e+01,  1.46989516e+01,  1.04003935e+02],
       [ 1.01333057e+01,  1.33257429e+01,  1.05931984e+02],
       [ 9.38629399e+00,  1.39040038e+01,  1.10363757e+02],
       [ 1.13412247e+01,  1.61090392e+01,  1.10731822e+02]])
# split the data into training and test sets (features: first two columns; target: last column)
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test=train_test_split(data[:,:2],data[:,-1])
X_train.shape,X_test.shape,y_train.shape,y_test.shape
((75, 2), (25, 2), (75,), (25,))

2 Predict each sample's target as the mean of the targets of its K nearest neighbors

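# Rewrite the prediction function for regression:
# return the mean of the K nearest neighbors' targets as the prediction for x.
# K (the number of neighbors) is assumed to be defined in part 1 of this experiment.
def euclidean(a,b):
    # Euclidean distance between two feature vectors
    return np.sqrt(np.sum((a-b)**2))
def calcu_distance_return(x,X_train,y_train):
    KNN_x=[]
    # iterate over every sample in the training set
    for i in range(X_train.shape[0]):
        d=euclidean(x,X_train[i])
        if len(KNN_x)<K:
            KNN_x.append((d,y_train[i]))
        else:
            # sort ascending so the farthest of the current K neighbors is last
            KNN_x.sort()
            # replace the farthest neighbor if the new sample is closer
            if d<KNN_x[-1][0]:
                KNN_x[-1]=(d,y_train[i])
    knn_label=[item[1] for item in KNN_x]
    return np.mean(knn_label)
# predict the whole test set
def predict(X_test):
    y_pred=np.zeros(X_test.shape[0])
    for i in range(X_test.shape[0]):
        y_hat_i=calcu_distance_return(X_test[i],X_train,y_train)
        y_pred[i]=y_hat_i
    return y_pred
# output the predictions
y_pred= predict(X_test)
y_pred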
array([-48.77391118, -61.82953142,  -7.08681066,  31.79119171,
        89.89605669,  49.28413251,  52.97713079,  33.48545677,
        63.32131747,  98.05154212, -55.78008004,  98.04210317,
         7.02443886, -19.02562562,  11.49285143, -13.67585848,
        52.97713079,  21.82629113,  10.45687568,  55.14568247,
        -9.552268  ,  94.91846026, -11.51277047,  22.35944142,
        86.13169115])
y_test
array([-41.53734685, -58.05744051,  -1.46024067,  40.57597798,
       103.01758072,  66.10127272,  46.66025056,  56.97482206,
        63.0249467 , 100.8255246 , -54.62086294,  91.25029492,
         3.3520749 , -23.59621905,   1.17352349, -20.40784363,
        46.35507328,  21.23129715,   5.09073378,  59.21784029,
         7.90719675,  98.23545178,  -1.68276177,  17.71925914,
        78.40425661])
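The loop-based predictor above can be cross-checked against a vectorized NumPy version and against scikit-learn's `KNeighborsRegressor`, which with its default uniform weights also predicts the mean target of the k nearest neighbors. This is a sketch under assumptions: the data are regenerated the same way as above but with `np.random.default_rng(10)` and `random_state=0`, so the numbers will not match the arrays printed here; `k = 3` is an assumed value of K; and `knn_regress` is a hypothetical helper name.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor

def knn_regress(X_train, y_train, X_query, k):
    # pairwise Euclidean distances, shape (n_query, n_train)
    d = np.sqrt(((X_query[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=2))
    # indices of the k smallest distances in each row (order within the k is irrelevant)
    idx = np.argpartition(d, k - 1, axis=1)[:, :k]
    return y_train[idx].mean(axis=1)

# stand-in data built like the experiment above (assumed seed and split)
rng = np.random.default_rng(10)
x1 = np.linspace(-10, 10, 100)
x2 = np.linspace(-5, 15, 100)
data = np.column_stack([x1, x2, 5 * x1 + 4 * x2]) + rng.normal(0.1, 1, (100, 3))
X_tr, X_te, y_tr, y_te = train_test_split(data[:, :2], data[:, 2], random_state=0)

k = 3  # assumed value of K
mine = knn_regress(X_tr, y_tr, X_te, k)
ref = KNeighborsRegressor(n_neighbors=k).fit(X_tr, y_tr).predict(X_te)
print(np.allclose(mine, ref))
```

`np.argpartition` avoids a full sort of each row, which is the vectorized counterpart of keeping only the K closest samples in the loop.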

3 Evaluate with R² (the coefficient of determination)

from sklearn.metrics import r2_score
r2_score(y_test,y_pred)
0.9634297760055799
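R² compares the model's squared error with that of always predicting the mean: R² = 1 - SS_res / SS_tot, where SS_res is the sum of squared residuals and SS_tot the sum of squared deviations of y from its mean. A value of 1 means a perfect fit, and 0 means no better than the mean. The sketch below computes it by hand on a small made-up example and compares it with `r2_score`; `r2` is a helper written for illustration.

```python
import numpy as np
from sklearn.metrics import r2_score

def r2(y_true, y_pred):
    # R^2 = 1 - SS_res / SS_tot
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    return 1 - ss_res / ss_tot

y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_hat = np.array([2.5, 0.0, 2.0, 8.0])
print(r2(y_true, y_hat), r2_score(y_true, y_hat))  # both about 0.9486
```

Here SS_res = 1.5 and SS_tot = 29.1875, so R² = 1 - 1.5 / 29.1875 ≈ 0.9486.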

