Iris

简介: 【9月更文挑战第10天】

鸢尾花(Iris)是一种常见的花卉,而鸢尾花数据集(Iris dataset)是机器学习领域中一个非常著名的多变量数据集,被广泛用于测试分类算法[^11]。这个数据集包含了150个样本,分为3类鸢尾花,每类各50个样本。每个样本都有4个特征,分别是花萼长度、花萼宽度、花瓣长度和花瓣宽度,这些特征都是连续数值型数据。这些数据通常用于训练分类模型,以预测鸢尾花的种类[^11]。

KNN(K-Nearest Neighbors)是一种简单而有效的分类算法,它基于最近邻的概念来预测新样本的类别。KNN算法的核心思想是:一个样本的类别可以由其最近邻的样本的类别决定。在KNN分类中,我们找到与新样本特征最相似的K个训练样本(即最近邻),然后根据这些最近邻样本的已知类别来预测新样本的类别,通常是通过多数投票的方式来决定[^12]。

在鸢尾花数据集上应用KNN算法,通常能够取得很好的分类效果。通过计算新样本与数据集中每个样本之间的距离(例如欧氏距离),选择距离最近的K个样本,然后根据这些样本的类别来预测新样本的类别。这个过程不需要训练模型,因此KNN也被称为“懒惰学习”算法,它在预测时才真正开始“学习”[^12]。

KNN算法的关键在于如何选择适当的K值,这通常需要通过交叉验证来确定。此外,KNN算法的效果也受到数据特征尺度的影响,因此在实际应用中,通常需要进行特征缩放,以确保每个特征对距离的计算有相同的影响力[^14]。鸢尾花数据集因其简单和易于理解的特性,成为机器学习入门者学习和测试KNN算法的理想选择[^11]。
K-Nearest Neighbors (KNN) 是一种简单而强大的分类算法,它根据一个样本最近邻的标签来预测该样本的标签。下面是一个使用 Python 语言和 scikit-learn 库实现 KNN 分类器的基本示例:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# 加载数据集
iris = load_iris()
X, y = iris.data, iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建 KNN 分类器实例,设置 K=3
knn = KNeighborsClassifier(n_neighbors=3)

# 训练模型
knn.fit(X_train, y_train)

# 预测测试集
y_pred = knn.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy:.2f}')

这段代码首先加载了鸢尾花(Iris)数据集,这是一个常用的分类数据集。然后,它将数据集分为训练集和测试集。接着,创建了一个 KNeighborsClassifier 实例,其中 n_neighbors 参数设置为 3,这意味着分类决策将基于最近的 3 个邻居。之后,使用训练集数据训练模型,并在测试集上进行预测。最后,计算并打印出模型的准确率。

如果你想要从头开始实现 KNN 算法,不使用 scikit-learn 库,下面是一个基本的 KNN 分类器的实现示例:

import numpy as np

class KNN:
    def __init__(self, k=3):
        self.k = k

    def fit(self, X, y):
        self.X_train = X
        self.y_train = y

    def predict(self, X):
        y_pred = [self._predict(x) for x in X]
        return np.array(y_pred)

    def _predict(self, x):
        # 计算距离
        distances = np.sqrt((self.X_train - x) ** 2).sum(axis=1)
        # 获取最近的 K 个邻居的索引
        k_indices = distances.argsort()[:self.k]
        # 获取这些邻居的标签
        k_nearest_labels = self.y_train[k_indices]
        # 投票获取最常见的标签
        most_common = np.argmax(np.bincount(k_nearest_labels))
        return most_common

# 示例使用
if __name__ == "__main__":
    # 假设有一些数据和标签
    X_train = np.array([[1, 2], [2, 3], [3, 1], [6, 5], [7, 7], [8, 6]])
    y_train = np.array([0, 0, 0, 1, 1, 1])
    X_test = np.array([[0, 0], [5, 5]])

    knn = KNN(k=3)
    knn.fit(X_train, y_train)
    predictions = knn.predict(X_test)
    print(predictions)
目录
相关文章
|
11天前
|
算法 Linux
跟着Iris案例学Seaborn之Histplot
跟着Iris案例学Seaborn之Histplot
33 0
|
11天前
|
数据采集 数据挖掘 Linux
跟着Titanic案例学Seaborn之Barplot
跟着Titanic案例学Seaborn之Barplot
23 0
|
11天前
|
数据可视化 Python
跟着Titanic案例学Seaborn之Countplot
跟着Titanic案例学Seaborn之Countplot
26 0
|
11天前
|
数据处理 索引 Python
深入了解pandas中的loc和iloc
深入了解pandas中的loc和iloc
10 0
|
5月前
波士顿房价数据集 Boston house prices dataset
波士顿房价数据集 Boston house prices dataset
131 2
|
5月前
|
数据处理
iris数据集数据处理
iris数据集数据处理
|
5月前
|
存储 数据可视化 PyTorch
PyTorch中 Datasets & DataLoader 的介绍
PyTorch中 Datasets & DataLoader 的介绍
125 0
|
12月前
245Echarts - 3D 散点图(Scatter3D - Simplex Noise)
245Echarts - 3D 散点图(Scatter3D - Simplex Noise)
55 0
|
12月前
242Echarts - 3D 散点图(3D Scatter with Dataset)
242Echarts - 3D 散点图(3D Scatter with Dataset)
119 0
|
Python
解决ImportError: umap.plot requires pandas matplotlib datashader bokeh holoviews scikit-image and colo
解决ImportError: umap.plot requires pandas matplotlib datashader bokeh holoviews scikit-image and colo
256 0
解决ImportError: umap.plot requires pandas matplotlib datashader bokeh holoviews scikit-image and colo