K-近邻算法
以鸢尾花数据为例,以python为工具,对其该数据集进行分类。
在python的机器学习库sklearn中,
k-近邻模型默认使用近邻数为k=5。
当k值过小时容易产生过拟合,
当k值过大时容易产生欠拟合。
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# 加载鸢尾花数据集
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target)
# 创建估计器
# 模型的近邻数默认为5,也可以设置不同的近邻数
model = KNeighborsClassifier()
# model = KNeighborsClassifier(n_neighbors=5)
model.fit(X_train, y_train)
# 训练集准确率
train_score = model.score(X_train, y_train)
# 测试集准确率
test_score = model.score(X_test, y_test)
print("train score:", train_score)
print("test score:", test_score)
输出预测结果
y_pred = model.predict(X_test)
print(y_pred)
参考:
<Python机器学习实战 吕云翔>