KNN算法
1.k kk近邻法是基本且简单的分类与回归方法。k kk近邻法的基本做法是:对给定的训练实例点和输入实例点,首先确定输入实例点的k kk个最近邻训练实例点,然后利用这k kk个训练实例点的类的多数来预测输入实例点的类。
2.k kk近邻模型对应于基于训练数据集对特征空间的一个划分。k kk近邻法中,当训练集、距离度量、k kk值及分类决策规则确定后,其结果唯一确定,没有近似,他没有学习参数。
3.k kk近邻法三要素:距离度量、k kk值的选择和分类决策规则。常用的距离度量是欧氏距离及更一般的pL距离。k kk值小时,k kk近邻模型更复杂;k kk值大时,k kk近邻模型更简单。k kk值的选择反映了对近似误差与估计误差之间的权衡,通常由交叉验证选择最优的k kk。
常用的分类决策规则是多数表决,对应于经验风险最小化。
4.k kk近邻法的实现需要考虑如何快速搜索k个最近邻点。kd树是一种便于对k维空间中的数据进行快速检索的数据结构。kd树是二叉树,表示对k kk维空间的一个划分,其每个结点对应于k kk维空间划分中的一个超矩形区域。利用kd树可以省去对大部分数据点的搜索, 从而减少搜索的计算量。
前言 距离度量
在机器学习算法中,我们经常需要计算样本之间的相似度,通常的做法是计算样本之间的距离。
设x xx和y yy为两个向量,求它们之间的距离。
这里用Numpy实现,设和为ndarray <numpy.ndarray>,它们的shape都是(N,)
d dd为所求的距离,是个浮点数(float)。
1) 欧式距离
欧几里得度量(euclidean metric)(也称欧氏距离)是一个通常采用的距离定义,指在m维空间中两个点之间的真实距离,或者向量的自然长度(即该点到原点的距离)。在二维和三维空间中的欧氏距离就是两点之间的实际距离。
距离公式:
代码实现:
def euclidean(x, y): return np.sqrt(np.sum((x - y)**2))
(2) 曼哈顿距离(Manhattan distance)
想象你在城市道路里,要从一个十字路口开车到另外一个十字路口,驾驶距离是两点间的直线距离吗?显然不是,除非你能穿越大楼。实际驾驶距离就是这个“曼哈顿距离”。而这也是曼哈顿距离名称的来源,曼哈顿距离也称为城市街区距离(City Block distance)。
距离公式:
代码实现:
def manhatan_distance(x,y): return np.sum(np.abs(x-y))
(3) 切比雪夫距离(Chebyshev distance)
在数学中,切比雪夫距离(Chebyshev distance)或是L∞度量,是向量空间中的一种度量,二个点之间的距离定义是其各坐标数值差绝对值的最大值。以数学的观点来看,切比雪夫距离是由一致范数(uniform norm)(或称为上确界范数)所衍生的度量,也是超凸度量(injective metric space)的一种。
距离公式:
若将国际象棋棋盘放在二维直角座标系中,格子的边长定义为1,座标的x xx轴及y yy轴和棋盘方格平行,原点恰落在某一格的中心点,则王从一个位置走到其他位置需要的步数恰为二个位置的切比雪夫距离,因此切比雪夫距离也称为棋盘距离。例如位置F6和位置E2的切比雪夫距离为4。任何一个不在棋盘边缘的位置,和周围八个位置的切比雪夫距离都是1。
代码实现:
def chebysev_distance(x,y): return np.max(np.abs(x-y))
(4) 闵可夫斯基距离(Minkowski distance)
闵氏空间指狭义相对论中由一个时间维和三个空间维组成的时空,为俄裔德国数学家闵可夫斯基(H.Minkowski,1864-1909)最先表述。他的平坦空间(即假设没有重力,曲率为零的空间)的概念以及表示为特殊距离量的几何学是与狭义相对论的要求相一致的。闵可夫斯基空间不同于牛顿力学的平坦空间。p pp取1或2时的闵氏距离是最为常用的,p = 2 p= 2p=2即为欧氏距离,而p = 1 p =1p=1时则为曼哈顿距离。
当p pp取无穷时的极限情况下,可以得到切比雪夫距离。
距离公式:
代码实现:
def minkowski(x, y, p): return np.sum(np.abs(x - y)**p)**(1 / p)
(5) 汉明距离(Hamming distance)
汉明距离是使用在数据传输差错控制编码里面的,汉明距离是一个概念,它表示两个(相同长度)字对应位不同的数量,我们以表示两个字,之间的汉明距离。对两个字符串进行异或运算,并统计结果为1的个数,那么这个数就是汉明距离。
距离公式:
def hamming(x,y): return np.sum(x!=y)/len(x)
(6) 余弦相似度(Cosine Similarity)
余弦相似性通过测量两个向量的夹角的余弦值来度量它们之间的相似性。0度角的余弦值是1,而其他任何角度的余弦值都不大于1;并且其最小值是-1。从而两个向量之间的角度的余弦值确定两个向量是否大致指向相同的方向。两个向量有相同的指向时,余弦相似度的值为1;两个向量夹角为90°时,余弦相似度的值为0;两个向量指向完全相反的方向时,余弦相似度的值为-1。这结果是与向量的长度无关的,仅仅与向量的指向方向相关。余弦相似度通常用于正空间,因此给出的值为0到1之间。
二维空间为例,上图的a aa和b bb是两个向量,我们要计算它们的夹角θ。余弦定理告诉我们,可以用下面的公式求得:
代码实现:
def square_rooted(x): return np.sqrt(np.sum(np.power(x,2)))
def cosine_similarity_distance(x,y): fenzi=np.sum(np.multiply(x,y)) fenmu=square_rooted(x)*square_rooted(y) return fenzi/fenmu
import numpy as np print(cosine_similarity_distance([3, 45, 7, 2], [2, 54, 13, 15]))
0.9722842517123499
KNN算法介绍
1.k kk近邻法是基本且简单的分类与回归方法。k kk近邻法的基本做法是:对给定的训练实例点和输入实例点,首先确定输入实例点的k kk个最近邻训练实例点,然后利用这k kk个训练实例点的类的多数来预测输入实例点的类。
2.k kk近邻模型对应于基于训练数据集对特征空间的一个划分。k kk近邻法中,当训练集、距离度量、k kk值及分类决策规则确定后,其结果唯一确定。
3.k kk近邻法三要素:距离度量、k kk值的选择和分类决策规则。常用的距离度量是欧氏距离。k kk值小时,k kk近邻模型更复杂;k kk值大时,k kk近邻模型更简单。k kk值的选择反映了对近似误差与估计误差之间的权衡,通常由交叉验证选择最优的k kk。
常用的分类决策规则是多数表决,对应于经验风险最小化。
4.k kk近邻法的实现需要考虑如何快速搜索k个最近邻点。kd树是一种便于对k维空间中的数据进行快速检索的数据结构。kd树是二叉树,表示对k kk维空间的一个划分,其每个结点对应于k kk维空间划分中的一个超矩形区域。利用kd树可以省去对大部分数据点的搜索, 从而减少搜索的计算量。
python实现,遍历所有数据点,找出n nn个距离最近的点的分类情况,少数服从多数
1 数据的准备
import numpy as np import pandas as pd import matplotlib.pyplot as plt from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from collections import Counter
导入鸢尾花数据集
iris = load_iris() iris
{'data': array([[5.1, 3.5, 1.4, 0.2], [4.9, 3. , 1.4, 0.2], [4.7, 3.2, 1.3, 0.2], [4.6, 3.1, 1.5, 0.2], [5. , 3.6, 1.4, 0.2], [5.4, 3.9, 1.7, 0.4], [4.6, 3.4, 1.4, 0.3], [5. , 3.4, 1.5, 0.2], [4.4, 2.9, 1.4, 0.2], [4.9, 3.1, 1.5, 0.1], [5.4, 3.7, 1.5, 0.2], [4.8, 3.4, 1.6, 0.2], [4.8, 3. , 1.4, 0.1], [4.3, 3. , 1.1, 0.1], [5.8, 4. , 1.2, 0.2], [5.7, 4.4, 1.5, 0.4], [5.4, 3.9, 1.3, 0.4], [5.1, 3.5, 1.4, 0.3], [5.7, 3.8, 1.7, 0.3], [5.1, 3.8, 1.5, 0.3], [5.4, 3.4, 1.7, 0.2], [5.1, 3.7, 1.5, 0.4], [4.6, 3.6, 1. , 0.2], [5.1, 3.3, 1.7, 0.5], [4.8, 3.4, 1.9, 0.2], [5. , 3. , 1.6, 0.2], [5. , 3.4, 1.6, 0.4], [5.2, 3.5, 1.5, 0.2], [5.2, 3.4, 1.4, 0.2], [4.7, 3.2, 1.6, 0.2], [4.8, 3.1, 1.6, 0.2], [5.4, 3.4, 1.5, 0.4], [5.2, 4.1, 1.5, 0.1], [5.5, 4.2, 1.4, 0.2], [4.9, 3.1, 1.5, 0.2], [5. , 3.2, 1.2, 0.2], [5.5, 3.5, 1.3, 0.2], [4.9, 3.6, 1.4, 0.1], [4.4, 3. , 1.3, 0.2], [5.1, 3.4, 1.5, 0.2], [5. , 3.5, 1.3, 0.3], [4.5, 2.3, 1.3, 0.3], [4.4, 3.2, 1.3, 0.2], [5. , 3.5, 1.6, 0.6], [5.1, 3.8, 1.9, 0.4], [4.8, 3. , 1.4, 0.3], [5.1, 3.8, 1.6, 0.2], [4.6, 3.2, 1.4, 0.2], [5.3, 3.7, 1.5, 0.2], [5. , 3.3, 1.4, 0.2], [7. , 3.2, 4.7, 1.4], [6.4, 3.2, 4.5, 1.5], [6.9, 3.1, 4.9, 1.5], [5.5, 2.3, 4. , 1.3], [6.5, 2.8, 4.6, 1.5], [5.7, 2.8, 4.5, 1.3], [6.3, 3.3, 4.7, 1.6], [4.9, 2.4, 3.3, 1. ], [6.6, 2.9, 4.6, 1.3], [5.2, 2.7, 3.9, 1.4], [5. , 2. , 3.5, 1. ], [5.9, 3. , 4.2, 1.5], [6. , 2.2, 4. , 1. ], [6.1, 2.9, 4.7, 1.4], [5.6, 2.9, 3.6, 1.3], [6.7, 3.1, 4.4, 1.4], [5.6, 3. , 4.5, 1.5], [5.8, 2.7, 4.1, 1. ], [6.2, 2.2, 4.5, 1.5], [5.6, 2.5, 3.9, 1.1], [5.9, 3.2, 4.8, 1.8], [6.1, 2.8, 4. , 1.3], [6.3, 2.5, 4.9, 1.5], [6.1, 2.8, 4.7, 1.2], [6.4, 2.9, 4.3, 1.3], [6.6, 3. , 4.4, 1.4], [6.8, 2.8, 4.8, 1.4], [6.7, 3. , 5. , 1.7], [6. , 2.9, 4.5, 1.5], [5.7, 2.6, 3.5, 1. ], [5.5, 2.4, 3.8, 1.1], [5.5, 2.4, 3.7, 1. ], [5.8, 2.7, 3.9, 1.2], [6. , 2.7, 5.1, 1.6], [5.4, 3. , 4.5, 1.5], [6. , 3.4, 4.5, 1.6], [6.7, 3.1, 4.7, 1.5], [6.3, 2.3, 4.4, 1.3], [5.6, 3. , 4.1, 1.3], [5.5, 2.5, 4. , 1.3], [5.5, 2.6, 4.4, 1.2], [6.1, 3. , 4.6, 1.4], [5.8, 2.6, 4. , 1.2], [5. , 2.3, 3.3, 1. ], [5.6, 2.7, 4.2, 1.3], [5.7, 3. , 4.2, 1.2], [5.7, 2.9, 4.2, 1.3], [6.2, 2.9, 4.3, 1.3], [5.1, 2.5, 3. , 1.1], [5.7, 2.8, 4.1, 1.3], [6.3, 3.3, 6. , 2.5], [5.8, 2.7, 5.1, 1.9], [7.1, 3. , 5.9, 2.1], [6.3, 2.9, 5.6, 1.8], [6.5, 3. , 5.8, 2.2], [7.6, 3. , 6.6, 2.1], [4.9, 2.5, 4.5, 1.7], [7.3, 2.9, 6.3, 1.8], [6.7, 2.5, 5.8, 1.8], [7.2, 3.6, 6.1, 2.5], [6.5, 3.2, 5.1, 2. ], [6.4, 2.7, 5.3, 1.9], [6.8, 3. , 5.5, 2.1], [5.7, 2.5, 5. , 2. ], [5.8, 2.8, 5.1, 2.4], [6.4, 3.2, 5.3, 2.3], [6.5, 3. , 5.5, 1.8], [7.7, 3.8, 6.7, 2.2], [7.7, 2.6, 6.9, 2.3], [6. , 2.2, 5. , 1.5], [6.9, 3.2, 5.7, 2.3], [5.6, 2.8, 4.9, 2. ], [7.7, 2.8, 6.7, 2. ], [6.3, 2.7, 4.9, 1.8], [6.7, 3.3, 5.7, 2.1], [7.2, 3.2, 6. , 1.8], [6.2, 2.8, 4.8, 1.8], [6.1, 3. , 4.9, 1.8], [6.4, 2.8, 5.6, 2.1], [7.2, 3. , 5.8, 1.6], [7.4, 2.8, 6.1, 1.9], [7.9, 3.8, 6.4, 2. ], [6.4, 2.8, 5.6, 2.2], [6.3, 2.8, 5.1, 1.5], [6.1, 2.6, 5.6, 1.4], [7.7, 3. , 6.1, 2.3], [6.3, 3.4, 5.6, 2.4], [6.4, 3.1, 5.5, 1.8], [6. , 3. , 4.8, 1.8], [6.9, 3.1, 5.4, 2.1], [6.7, 3.1, 5.6, 2.4], [6.9, 3.1, 5.1, 2.3], [5.8, 2.7, 5.1, 1.9], [6.8, 3.2, 5.9, 2.3], [6.7, 3.3, 5.7, 2.5], [6.7, 3. , 5.2, 2.3], [6.3, 2.5, 5. , 1.9], [6.5, 3. , 5.2, 2. ], [6.2, 3.4, 5.4, 2.3], [5.9, 3. , 5.1, 1.8]]), 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]), 'frame': None, 'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'), 'DESCR': '.. _iris_dataset:\n\nIris plants dataset\n--------------------\n\n**Data Set Characteristics:**\n\n :Number of Instances: 150 (50 in each of three classes)\n :Number of Attributes: 4 numeric, predictive attributes and the class\n :Attribute Information:\n - sepal length in cm\n - sepal width in cm\n - petal length in cm\n - petal width in cm\n - class:\n - Iris-Setosa\n - Iris-Versicolour\n - Iris-Virginica\n \n :Summary Statistics:\n\n ============== ==== ==== ======= ===== ====================\n Min Max Mean SD Class Correlation\n ============== ==== ==== ======= ===== ====================\n sepal length: 4.3 7.9 5.84 0.83 0.7826\n sepal width: 2.0 4.4 3.05 0.43 -0.4194\n petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)\n petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)\n ============== ==== ==== ======= ===== ====================\n\n :Missing Attribute Values: None\n :Class Distribution: 33.3% for each of 3 classes.\n :Creator: R.A. Fisher\n :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n :Date: July, 1988\n\nThe famous Iris database, first used by Sir R.A. Fisher. The dataset is taken\nfrom Fisher\'s paper. Note that it\'s the same as in R, but not as in the UCI\nMachine Learning Repository, which has two wrong data points.\n\nThis is perhaps the best known database to be found in the\npattern recognition literature. Fisher\'s paper is a classic in the field and\nis referenced frequently to this day. (See Duda & Hart, for example.) The\ndata set contains 3 classes of 50 instances each, where each class refers to a\ntype of iris plant. One class is linearly separable from the other 2; the\nlatter are NOT linearly separable from each other.\n\n.. topic:: References\n\n - Fisher, R.A. "The use of multiple measurements in taxonomic problems"\n Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to\n Mathematical Statistics" (John Wiley, NY, 1950).\n - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.\n (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.\n - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System\n Structure and Classification Rule for Recognition in Partially Exposed\n Environments". IEEE Transactions on Pattern Analysis and Machine\n Intelligence, Vol. PAMI-2, No. 1, 67-71.\n - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE Transactions\n on Information Theory, May 1972, 431-433.\n - See also: 1988 MLC Proceedings, 54-64. Cheeseman et al"s AUTOCLASS II\n conceptual clustering system finds 3 classes in the data.\n - Many, many more ...', 'feature_names': ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'], 'filename': 'iris.csv', 'data_module': 'sklearn.datasets.data'}
iris = load_iris() df = pd.DataFrame(iris.data, columns=iris.feature_names) df["target"]=iris.target df.columns=iris.feature_names+["target"] df
sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target | |
0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
... | ... | ... | ... | ... | ... |
145 | 6.7 | 3.0 | 5.2 | 2.3 | 2 |
146 | 6.3 | 2.5 | 5.0 | 1.9 | 2 |
147 | 6.5 | 3.0 | 5.2 | 2.0 | 2 |
148 | 6.2 | 3.4 | 5.4 | 2.3 | 2 |
149 | 5.9 | 3.0 | 5.1 | 1.8 | 2 |
150 rows × 5 columns
df.head()
sepal length (cm) |
sepal width (cm) | petal length (cm) | petal width (cm) | target | |
0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
选择长和宽的数据进行可视化
#选取前100行数据进行可视化 plt.figure(figsize=(12, 8)) plt.scatter(df[:50]["sepal length (cm)"], df[:50]["sepal width (cm)"], label='0') plt.scatter(df[50:100]["sepal length (cm)"], df[50:100]["sepal width (cm)"], label='1') plt.xlabel('sepal length', fontsize=18) plt.ylabel('sepal width', fontsize=18) plt.legend() plt.show()
2 划分训练数据和测试数据
from sklearn.model_selection import train_test_split X_train,X_test,y_train,y_test=train_test_split(df.iloc[:100,:2].values,df.iloc[:100,-1].values) X_train.shape,X_test.shape,y_train.shape,y_test.shape
((75, 2), (25, 2), (75,), (25,))
X_train,y_train
(array([[5. , 3.3], [4.6, 3.4], [5.2, 4.1], [5.7, 2.8], [5.1, 3.4], [4.8, 3. ], [5.9, 3.2], [5.7, 3.8], [4.8, 3.4], [5.3, 3.7], [5.1, 3.8], [5.5, 2.4], [6. , 2.2], [5.5, 4.2], [5.5, 2.6], [5.4, 3.4], [4.4, 2.9], [6. , 2.9], [5.8, 2.7], [4.4, 3.2], [5.6, 2.9], [5.8, 2.7], [6.7, 3.1], [6. , 2.7], [5.7, 2.9], [4.6, 3.2], [4.9, 3.1], [7. , 3.2], [4.7, 3.2], [5.1, 2.5], [6.3, 2.3], [4.6, 3.1], [6.4, 3.2], [6.6, 3. ], [4.6, 3.6], [5.5, 2.4], [5.6, 3. ], [5.1, 3.7], [6.1, 2.8], [5.6, 2.7], [4.8, 3.1], [4.8, 3. ], [5. , 3.5], [6.2, 2.2], [6. , 3.4], [5.1, 3.3], [5.4, 3.9], [5.7, 2.6], [6.7, 3.1], [4.5, 2.3], [4.8, 3.4], [4.9, 2.4], [5.8, 4. ], [5. , 3. ], [6.6, 2.9], [6.1, 2.9], [5. , 3.5], [6.8, 2.8], [5. , 2.3], [5.4, 3. ], [4.3, 3. ], [4.9, 3.1], [4.9, 3. ], [5.1, 3.8], [5.1, 3.5], [5.5, 2.5], [5. , 3.6], [5. , 3.4], [5.4, 3.9], [5.1, 3.8], [5.1, 3.5], [5.2, 3.5], [5.8, 2.6], [6.4, 2.9], [6.1, 2.8]]), array([0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1]))
3 通过K个近邻预测的标签的距离来预测当前样本的标签
#定义邻居数量 from collections import Counter K=3
KNN_x=[]
for i in range(X_train.shape[0]): if len(KNN_x)<K: KNN_x.append((euclidean(X_test[0],X_train[i]),y_train[i]))
KNN_x
[(0.6324555320336757, 0), (0.9219544457292889, 0), (1.3999999999999995, 0)]
count=Counter([item[1] for item in KNN_x]) count
Counter({0: 3})
count.items()
dict_items([(0, 3)])
sorted(count.items(),key=lambda x:x[1])[-1][0]
0
#返回任意一个样本x的标签 def calcu_distance_return(x,X_train,y_train): KNN_x=[] #遍历训练集中的每个样本 for i in range(X_train.shape[0]): if len(KNN_x)<K: KNN_x.append((euclidean(x,X_train[i]),y_train[i])) else: KNN_x.sort() for j in range(K): if (euclidean(x,X_train[i]))< KNN_x[j][0]: KNN_x[j]=(euclidean(x,X_train[i]),y_train[i]) break knn_label=[item[1] for item in KNN_x] counter_knn=Counter(knn_label) return sorted(counter_knn.items(),key=lambda item:item[1])[-1][0]
#对整个测试集进行预测 def predict(X_test): y_pred=np.zeros(X_test.shape[0]) for i in range(X_test.shape[0]): y_hat_i=calcu_distance_return(X_test[i],X_train,y_train) y_pred[i]=y_hat_i return y_pred
4 计算准确率
#输出预测结果 y_pred= predict(X_test).astype("int32")
y_pred
array([1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0])
y_test=y_test.astype("int32") y_test
array([1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0])
#计算准确率 np.sum(y_pred==y_test)/X_test.shape[0]
1.0