在机器学习领域,决策树是一种常见的分类方法,它通过从数据中学习简单的决策规则来预测目标变量。本文将介绍如何使用Python的scikit-learn
库来加载Iris数据集、训练一个决策树模型、评估其准确率,并最终可视化这个模型。
作者介绍:10年大厂数据\经营分析经验,现任大厂数据部门负责人。
会一些的技术:数据分析、算法、SQL、大数据相关、python
欢迎加入社区:码上找工作
作者专栏每日更新:
备注说明:方便大家阅读,统一使用python,带必要注释,公众号 数据分析螺丝钉
1. 加载Iris数据集
Iris数据集是机器学习中最著名的数据集之一,由英国统计学家和生物学家Ronald Fisher在1936年介绍。它包含150个样本,每个样本都是关于鸢尾花的测量数据,包括:
- 花萼长度(Sepal Length)
- 花萼宽度(Sepal Width)
- 花瓣长度(Petal Length)
- 花瓣宽度(Petal Width)
这些样本分属于三个鸢尾花种类,每种50个样本:
- Setosa
- Versicolor
- Virginica
Iris数据集的目的是基于这四个特征预测鸢尾花的种类,它是分类任务中的一个经典问题
2. 训练决策树模型
使用决策树对Iris数据集进行分类首先需要划分数据集为训练集和测试集,这可以通过train_test_split
函数实现,通常我们保留一部分数据(如20%)作为测试集。之后,创建DecisionTreeClassifier
的实例,并调用其fit
方法用训练集训练模型。
它模拟了人类做决策的过程,通过一系列的问题来对数据进行分类。一个决策树包括:
- 节点(Nodes):表示一个特征或属性。
- 边/分支(Edges/Branches):代表决策规则。
- 叶节点(Leaf nodes):代表一个分类或决策的输出结果。
在决策树中,从根节点(最顶部的节点)开始,根据每个节点代表的特征对数据进行分割,直到达到叶节点,叶节点表示最终的分类结果。
3. 评估模型的准确率
模型训练完成后,可以通过预测测试集的标签并与真实标签进行比较来评估模型的性能。accuracy_score
函数能够计算模型预测的准确率,即正确预测的样本占总样本的比例。
4. 可视化决策树
scikit-learn
提供了plot_tree
函数,可以将训练好的决策树模型可视化。这个可视化展示了模型的决策过程,包括决策的条件、树的分支和叶节点等信息,使得模型的决策规则直观易懂。
详细步骤与代码实现
首先,安装必要的库:
pip install scikit-learn matplotlib
然后,通过以下Python代码来实现上述步骤:
# 使用Python实现一个简单的决策树分类器 from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from sklearn.tree import DecisionTreeClassifier from sklearn.metrics import accuracy_score import matplotlib.pyplot as plt from sklearn import tree # 加载Iris数据集 iris = load_iris() X = iris.data y = iris.target # 划分训练集和测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) # 创建决策树分类器实例 dt_clf = DecisionTreeClassifier(max_depth=3) # 训练决策树分类器 dt_clf.fit(X_train, y_train) # 在测试集上进行预测 y_pred = dt_clf.predict(X_test) # 计算准确率 accuracy = accuracy_score(y_test, y_pred) print(f"Accuracy: {accuracy:.2f}") # 可视化决策树 plt.figure(figsize=(20,10)) tree.plot_tree(dt_clf, filled=True, feature_names=iris.feature_names, class_names=iris.target_names) plt.show()
可视化呈现
解读
这张图展示了一个训练好的决策树模型,用于Iris数据集的分类问题。每个方框代表树中的一个节点,我们从顶部开始解读:
- 根节点:它显示“petal width (cm) <= 0.8”表示根据花瓣宽度的值决定如何分割数据。节点还提供了“gini = 0.667”,这是基尼不纯度,一个衡量分支质量的指标;“samples = 120”表示该节点包含120个样本;“value = [40, 40, 39]”表示这120个样本中,有40个属于每个种类;而“class = versicolor”表示这个节点中最多的种类是Versicolor。
- 第二层节点:根节点分为两个子节点。
- 左边的子节点是一个叶节点,表示分类为Setosa的条件满足(花瓣宽度小于或等于0.8cm)。它是纯净的,所有40个样本都属于Setosa种类(gini = 0.0)。
- 右边的子节点进一步根据“petal length (cm) <= 4.75”划分。
- 第三层节点:
- 左边显示基于花瓣宽度的进一步划分(小于或等于1.65cm),这导致了一个几乎完全纯净的叶节点,其中36个样本中有35个属于Versicolor,1个属于Virginica。
- 右边的节点是基于花瓣宽度小于或等于1.75cm的另一个划分。
- 第四层节点:展示了两个叶节点。
- 左边的节点是纯净的,只有一个Virginica样本。
- 右边的节点几乎纯净,它有34个Virginica样本和一个Versicolor样本。
在这个决策树中,大多数叶节点的gini系数很低,意味着它们的分类是高度纯净的。从树的结构我们可以得出,花瓣宽度和长度是非常重要的特征,用于区分Iris花的种类。
最终,这个决策树提供了一个决策路径:通过检查花瓣的宽度和长度,我们可以将Iris花分为Setosa、Versicolor或Virginica。这个可视化非常有助于理解模型是如何基于花卉的物理特征做出分类决策的。
欢迎关注微信公众号 数据分析螺丝钉