基于sklearn决策树算法对鸢尾花数据进行分类

简介: 基于sklearn决策树算法对鸢尾花数据进行分类

决策树

       决策树是属于有监督机器学习的一种,起源非常早,符合直觉并且非常直观,模仿人类做决 策的过程,早期人工智能模型中有很多应用,现在更多的是使用基于决策树的一些集成学习 的算法。这章我们把决策树算法理解透彻非常有利于后面去学习集成学习。


特点

1. 可以处理非线性的问题


2. 可解释性强 没有θ


3. 模型简单,模型预测效率高 if else


4. 不容易显示的使用函数表达,不可微


决策树模型生成和预测

模型生成:通过大量数据生成一颗非常好的树,用这棵树来预测新来的数据


预测:来一条新数据,按照生成好的树的标准,落到某一个节点上


生成决策树所需分裂指标

常用分裂条件

Gini 系数(CART)

基尼系数是指国际上通用的、用以衡量一个国家或地区居民收入差距的常用指标。 基尼系数最大为“1”,最小等于“0”。基尼系数越接近 0 表明收入分配越是趋向平 等。国际惯例把 0.2 以下视为收入绝对平均,0.2-0.3 视为收入比较平均;0.3-0.4 视为收入 相对合理;0.4-0.5 视为收入差距较大,当基尼系数达到 0.5 以上时,则表示收入悬殊。 基尼指数最早由意大利统计与社会学家 Corrado Gini 在 1912 年提出。 其具体含义是指,在全部居民收入中,用于进行不平均分配的那部分收入所占的比例。 基尼系数最大为“1”,最小等于“0”。前者表示居民之间的收入分配绝对不平均,即 100% 的收入被一个单位的人全部占有了;而后者则表示居民之间的收入分配绝对平均,即人与人之间收入完全平等,没有任何差异。但这两种情况只是在理论上的绝对化形式,在实际生活 中一般不会出现。因此,基尼系数的实际数值只能介于 0~1 之间,基尼系数越小收入分配 越平均,基尼系数越大收入分配越不平均。国际上通常把 0.4 作为贫富差距的警戒线,大于 这一数值容易出现社会动荡。


信息增益(ID3)

在信息论里熵叫作信息量,即熵是对不确定性的度量。从控制论的角度来看,应叫不确定性。信息论的创始人香农在其著作《通信的数学理论》中提出了建立在概率统计模型上的 信息度量。他把信息定义为“用来消除不确定性的东西”。在信息世界,熵越高,则能传输 越多的信息,熵越低,则意味着传输的信息越少。还是举例说明,假设 Kathy 在买衣服的 时候有颜色,尺寸,款式以及设计年份四种要求,而 North 只有颜色和尺寸的要求,那么 在购买衣服这个层面上 Kathy 由于选择更多因而不确定性因素更大,最终 Kathy 所获取的 信息更多,也就是熵更大。所以信息量=熵=不确定性,通俗易懂。在叙述决策树时我们用 熵表示不纯度(Impurity)。

信息增益:分裂前的信息熵 减去 分裂后的信息熵 一个分裂导致的信息增益越大,代表这次分裂提升的纯度越高


信息增益率(C4.5)

对于多叉树,如果不限制分裂多少支,一次分裂就可以将信息熵降为 0,比如 ID3


如何平衡分裂情况与信息增益? 信息增益率:信息增益 除以 类别 本身的


MSE

用于回归树

经典决策树算法

决策树优缺点

优点


1. 决策过程接近人的思维习惯。


2. 模型容易解释,比线性模型具有更好的解释性。


3. 能清楚地使用图形化描述模型。


4. 处理定型特征比较容易。


缺点


1. 一般来说,决策树学习方法的准确率不如其他的模型。针对这种情况存在一些解决方案, 在后面的文章中为大家讲解。


2. 不支持在线学习。当有新样本来的时候,需要重建决策树。


3. 容易产生过拟合现象。

经典决策树算法

ID3 和 C4.5 比较

ID3(Iterative Dichotomiser 3,迭代二叉树 3 代)由 Ross Quinlan 于 1986 年提出。 1993 年,他对 ID3 进行改进设计出了 C4.5 算法。 我们已经知道 ID3 与 C4.5 的不同之处在于,ID3 根据信息增益选取特征构造决策树, 而 C4.5 则是以信息增益率为核心构造决策树。既然 C4.5 是在 ID3 的基础上改进得到的, 那么这两者的优缺点分别是什么? 使用信息增益会让 ID3 算法更偏向于选择值多的属性。信息增益反映给定一个条件后 不确定性减少的程度,必然是分得越细的数据集确定性更高,也就是信息熵越小,信息增益 越大。因此,在一定条件下,值多的属性具有更大的信息增益。而 C4.5 则使用信息增益率 选择属性。信息增益率通过引入一个被称作分裂信息(Split information)的项来惩罚取值较 多的属性,分裂信息用来衡量属性分裂数据的广度和均匀性。这样就改进了 ID3 偏向选择 值多属性的缺点。


相对于 ID3 只能处理离散数据,C4.5 还能对连续属性进行处理,具体步骤为:


1. 把需要处理的样本(对应根节点)或样本子集(对应子树)按照连续变量的大小从小到大进 行排序。 2. 假设该属性对应的不同的属性值一共有 N 个,那么总共有 N−1 个可能的候选分割阈值 点,每个候选的分割阈值点的值为上述排序后的属性值中两两前后连续元素的中点,根 据这个分割点把原来连续的属性分成 bool 属性。实际上可以不用检查所有 N−1 个分 割点。(连续属性值比较多的时候,由于需要排序和扫描,会使 C4.5 的性能有所下降。)


3. 用信息增益比率选择最佳划分。


C4.5 其他优点


1)在树的构造过程中可以进行剪枝,缓解过拟合;


2)能够对连续属性进行离散化处理(二 分法);


3)能够对缺失值进行处理; 缺点:构造树的过程需要对数据集进行多次顺序扫描和排序,导致算法低效; 刚才我们提到 信息增益对可取值数目较多的属性有所偏好;而信息增益率对可取值数目较 少的属性有所偏好!OK,两者结合一下就好了!


解决方法:先从候选属性中找出信息增益高于平均水平的属性,再从中选择增益率最高的。 而不是大家常说的直接选择信息增益率最高的属性!

代码实战对鸢尾花数据集分类

from six import StringIO
import pandas as pd
import numpy as np
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import matplotlib as mpl
import pydotplus
iris = load_iris()
data = pd.DataFrame(iris.data)
data.columns = iris.feature_names
data['Species'] = load_iris().target
# 准备数据
x = data.iloc[:, 0:4]
y = data.iloc[:, -1]
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.25, random_state=42)
# 训练决策树模型
tree_clf = DecisionTreeClassifier(max_depth=8, criterion='gini')
tree_clf.fit(x_train, y_train)
# 用测试集进行预测,得出精确率
y_test_hat = tree_clf.predict(x_test)
print("acc score:", accuracy_score(y_test, y_test_hat))
print(tree_clf.feature_importances_)
# 将决策树保存成图片
dot_data = StringIO()
tree.export_graphviz(
    tree_clf,
    out_file=dot_data,
    feature_names=iris.feature_names[:],
    class_names=iris.target_names,
    rounded=True,
    filled=True
)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png('tree.png')
acc score: 1.0
[0.03575134 0.         0.88187037 0.08237829]

我们看出用测试集预测的精确度为100%,特征中,第三个特征的重要程度系数最大,说明它对决策分类起很重要的作用


我们看出决策树的最大深度为6,而我们在前面设置的是8,那么这个超参数应该设置为多少呢,我们通过画图来分析

depth = np.arange(1, 15)
err_list = []
for d in depth:
    print(d)
    clf = DecisionTreeClassifier(criterion='gini', max_depth=d)
    clf.fit(x_train, y_train)
    y_test_hat = clf.predict(x_test)
    result = (y_test_hat == y_test)
    if d == 1:
        print(result)
    err = 1 - np.mean(result)
    print(100 * err)
    err_list.append(err)
    print(d, ' 错误率:%.2f%%' % (100 * err))
mpl.rcParams['font.sans-serif'] = ['SimHei']
plt.figure(facecolor='w')
plt.plot(depth, err_list, 'ro-', lw=2)
plt.xlabel('决策树深度', fontsize=15)
plt.ylabel('错误率', fontsize=15)
plt.title('决策树深度和过拟合', fontsize=18)
plt.grid(True)
plt.show()

从图中我们看出当决策树深度等于3的时候,错误率就以及接近0了


目录
相关文章
|
7天前
|
存储 算法 Java
Java中,树与图的算法涉及二叉树的前序、中序、后序遍历以及DFS和BFS搜索。
【6月更文挑战第21天】Java中,树与图的算法涉及二叉树的前序、中序、后序遍历以及DFS和BFS搜索。二叉树遍历通过访问根、左、右子节点实现。DFS采用递归遍历图的节点,而BFS利用队列按层次访问。以下是简化的代码片段:[Java代码略]
16 4
|
3天前
|
存储 算法 Linux
【数据结构和算法】---二叉树(1)--树概念及结构
【数据结构和算法】---二叉树(1)--树概念及结构
10 0
|
5天前
|
机器学习/深度学习 算法
基于鲸鱼优化的knn分类特征选择算法matlab仿真
**基于WOA的KNN特征选择算法摘要** 该研究提出了一种融合鲸鱼优化算法(WOA)与K近邻(KNN)分类器的特征选择方法,旨在提升KNN的分类精度。在MATLAB2022a中实现,WOA负责优化特征子集,通过模拟鲸鱼捕食行为的螺旋式和包围策略搜索最佳特征。KNN则用于评估特征子集的性能。算法流程包括WOA参数初始化、特征二进制编码、适应度函数定义(以分类准确率为基准)、WOA迭代搜索及最优解输出。该方法有效地结合了启发式搜索与机器学习,优化特征选择,提高分类性能。
|
7天前
|
算法 Java 机器人
Java数据结构与算法:AVL树
Java数据结构与算法:AVL树
|
6天前
|
机器学习/深度学习 算法
梯度提升树GBDT系列算法
在Boosting集成算法当中,我们逐一建立多个弱评估器(基本是决策树),并且下一个弱评估器的建立方式依赖于上一个弱评估器的评估结果,最终综合多个弱评估器的结果进行输出。
|
1天前
|
存储 算法 安全
加密算法概述:分类与常见算法
加密算法概述:分类与常见算法
|
1天前
|
算法
技术好文共享:算法之树表的查找
技术好文共享:算法之树表的查找
|
1天前
|
机器学习/深度学习 算法 数据挖掘
聚类算法:揭秘数据背后的规律
聚类算法:揭秘数据背后的规律
|
4天前
|
存储 算法
【C/数据结构与算法】:树和二叉树
【C/数据结构与算法】:树和二叉树
8 0
|
2天前
|
机器学习/深度学习 自然语言处理 算法
m基于深度学习的OFDM+QPSK链路信道估计和均衡算法误码率matlab仿真,对比LS,MMSE及LMMSE传统算法
**摘要:** 升级版MATLAB仿真对比了深度学习与LS、MMSE、LMMSE的OFDM信道估计算法,新增自动样本生成、复杂度分析及抗频偏性能评估。深度学习在无线通信中,尤其在OFDM的信道估计问题上展现潜力,解决了传统方法的局限。程序涉及信道估计器设计,深度学习模型通过学习导频信息估计信道响应,适应频域变化。核心代码展示了信号处理流程,包括编码、调制、信道模拟、降噪、信道估计和解调。
23 8