基于 sklearn 的鸢尾花分类

简介: 基于 sklearn 的鸢尾花分类

基于 sklearn 的鸢尾花分类


预测两个鸢尾花种属性:萼片宽度和萼片长度。


1. 导入鸢尾花数据集


  • 导入sklearn自带的iris数据集。
  • 一个data数组, ,其中,对于每个实例,我们都有萼片长度,萼片宽度,花瓣长度和花瓣宽度的实际值(请注意,出于效率原因,scikit-learn 方法使用了 NumPy ndarrays,而不是更具描述性但效率更低的 Python 词典或列表。这个数组的形状是(150, 4),这意味着我们有 150 行(每个实例一个)和四列(每个特征一个)。
  • 一个target数组,值在 0 到 2 的范围内,对应于每种鸢尾种类(0:山鸢尾,1:杂色鸢尾和 2:弗吉尼亚鸢尾),您可以通过打印iris.target.target_names值来验证。
from sklearn import datasets
iris = datasets.load_iris()
X_iris, y_iris = iris.data, iris.target
print(X_iris.shape, y_iris.shape)
print(X_iris[0], y_iris[0])
(150, 4) (150,)
[5.1 3.5 1.4 0.2] 0


2.划分数据集


train_test_split函数自动构建训练和评估数据集,随机选择样本。

from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
X, y = X_iris[:, :2], y_iris
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=33)
print(X_train.shape, y_train.shape)
scaler = preprocessing.StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)
(112, 2) (112,)

最后三行在通常称为特征缩放的过程中修改训练集。对于每个特征, 计算平均值,从特征值中减去平均值,并将结果除以它们的标准差。缩放后,每个特征的平均值为零,标准差为 1。这种值的标准化(不会改变它们的分布,因为你可以通过在缩放之前和之后绘制X值来验证)是机器学习方法的常见要求,来避免具有大值的特征可能在权重上过重。


3.数据可视化


scatter函数简单地绘制每个实例的第一个特征值(萼片宽度)与其第二个特征值(萼片长度),并使用目标类值为每个类指定不同的颜色。

import matplotlib.pyplot as plt
colors = ['red', 'greenyellow', 'blue']
for i in range(len(colors)):
    xs = X_train[:, 0][y_train == i]
    ys = X_train[:, 1][y_train == i]
    plt.scatter(xs, ys, c=colors[i])
plt.legend(iris.target_names)
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
Text(0, 0.5, 'Sepal width')

image.png


4.模型训练


使用 scikit-learn 中的SGDClassifier。 SGD 代表随机梯度下降,这是一种非常流行的数值过程,用于查找函数的局部最小值(在本例中为损失函数),测量每个实例离我们边界的距离。该算法将通过最小化损失函数来学习超平面的系数。

from sklearn.linear_model import SGDClassifier
clf = SGDClassifier()
clf.fit(X_train, y_train) 
>>> x_min, x_max = X_train[:, 0].min() - .5, X_train[:, 0].max() +  
    .5
>>> y_min, y_max = X_train[:, 1].min() - .5, X_train[:, 1].max() + 
    .5
>>> xs = np.arange(x_min, x_max, 0.5)
>>> fig, axes = plt.subplots(1, 3)
>>> fig.set_size_inches(10, 6)
>>> for i in [0, 1, 2]:
>>>     axes[i].set_aspect('equal')
>>>     axes[i].set_title('Class '+ str(i) + ' versus the rest')
>>>     axes[i].set_xlabel('Sepal length')
>>>     axes[i].set_ylabel('Sepal width')
>>>     axes[i].set_xlim(x_min, x_max)
>>>     axes[i].set_ylim(y_min, y_max)
>>>     sca(axes[i])
>>>     plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train,
        cmap=plt.cm.prism)
>>>     ys = (-clf.intercept_[i] –
        Xs * clf.coef_[i, 0]) / clf.coef_[i, 1]
>>>     plt.plot(xs, ys, hold=True)    

image.png


5.预测分类


>>>print clf.predict(scaler.transform([[4.7, 3.1]]))

[0]

如果我们的分类器是正确的,这个鸢尾花是一个山鸢尾。可能你已经注意到我们正在从可能的三个类中预测一个类,但是线性模型本质上是二元的:缺少某些东西。你是对的。我们的预测程序结合了三个二元分类器的结果,并选择了更有置信度的类。在这种情况下,我们将选择与实例的距离更长的边界线。我们可以使用分类器decision_function方法检查:

>>>print clf.decision_function(scaler.transform([[4.7, 3.1]]))
[[ 19.73905808   8.13288449 -28.63499119]]


6.评估结果


给定分类器和评估数据集,它测量由分类器正确分类的实例的比例。

>>> from sklearn import metrics
>>> y_train_pred = clf.predict(X_train)
>>> print metrics.accuracy_score(y_train, y_train_pred)
0.821428571429 
>>> y_pred = clf.predict(X_test)
>>> print metrics.accuracy_score(y_test, y_pred)
0.684210526316 

在 scikit-learn 中,有几个评估函数;我们将展示三种流行的:精确率,召回率和 F1 得分(或 F 度量)。他们假设二元分类问题和两个类 - 正面和负面。在我们的例子中,正类可以是山鸢尾,而其他两个将合并为一个负类。

  • 精确率: 计算预测为正例的实例中,正确评估的比例(它测量分类器在表示实例为正时的正确程度) 。
  • 召回率:计算正确评估的正例示例的比例(测量我们的分类器在面对正例实例时的正确率)。
  • F1 得分:这是精确率和召回率的调和平均值。


目录
相关文章
|
机器学习/深度学习 算法 计算机视觉
使用sklearn进行特征选择
背景 一个典型的机器学习任务,是通过样本的特征来预测样本所对应的值。如果样本的特征少,我们会考虑增加特征。而现实中的情况往往是特征太多了,需要减少一些特征。
|
6月前
鸢尾花数据集分类问题(3)
鸢尾花数据集分类问题
37 2
|
6月前
鸢尾花数据集分类问题(1)
鸢尾花数据集分类问题
45 1
|
6月前
|
机器学习/深度学习
鸢尾花数据集分类问题(2)
鸢尾花数据集分类问题
44 1
|
6月前
鸢尾花数据集分类问题(4)
鸢尾花数据集分类问题
35 0
|
机器学习/深度学习 移动开发 资源调度
机器学习算法(二): 基于鸢尾花数据集的朴素贝叶斯(Naive Bayes)预测分类
机器学习算法(二): 基于鸢尾花数据集的朴素贝叶斯(Naive Bayes)预测分类
|
机器学习/深度学习 Python
【统计学习方法】K近邻对鸢尾花(iris)数据集进行多分类
【统计学习方法】K近邻对鸢尾花(iris)数据集进行多分类
247 0
|
机器学习/深度学习 数据可视化
随机森林和KNN分类结果可视化(Sklearn)
随机森林和KNN分类结果可视化(Sklearn)
260 0
|
机器学习/深度学习 Python
【统计学习方法】感知机对鸢尾花(iris)数据集进行二分类
【统计学习方法】感知机对鸢尾花(iris)数据集进行二分类
755 0
【统计学习方法】感知机对鸢尾花(iris)数据集进行二分类
|
数据采集 机器学习/深度学习 Python
【统计学习方法】朴素贝叶斯对鸢尾花(iris)数据集进行训练预测
【统计学习方法】朴素贝叶斯对鸢尾花(iris)数据集进行训练预测
418 0
【统计学习方法】朴素贝叶斯对鸢尾花(iris)数据集进行训练预测