Machine Learning机器学习之决策树算法 Decision Tree(附Python代码)

简介: Machine Learning机器学习之决策树算法 Decision Tree(附Python代码)

前言:

决策树是一种经典的机器学习算法,用于解决分类和回归问题。它的基本思想通过对数据集中的特征进行递归划分,构建一系列的决策规则,从而生成一个树状结构。在决策树中,每个内部节点表示对输入特征的一个测试,每个分支代表一个测试结果,而每个叶子节点表示一个类别或输出值。

决策树的发展历史可以追溯到20世纪50年代和60年代。最早的决策树算法是ID3(Iterative Dichotomiser 3),由Ross Quinlan于1986年提出。之后,C4.5算法和其改进版本C5.0也相继提出,扩展了ID3算法并加入了剪枝等优化方法。此外,还有 CART(Classification and Regression Trees)算法,由Leo Breiman等人于1984年提出,可用于分类和回归问题,并引入了基于基尼系数(Gini impurity)和均方误差(Mean Squared Error)的划分准则。

策树在机器学习领域得到了广泛的应用,具有许多优点,如易于理解、可解释性强、能够处理混合数据类型等。它适用于多种任务,包括分类、回归、特征选择等。此外,决策树还可以通过集成学习方法(如随机森林、梯度提升树)进一步提升性能,并解决决策树容易过拟合的问题。

总的来说,决策树是一种简单而有效的机器学习算法,为解决分类和回归问题提供了一种直观的方法。随着机器学习领域的发展,决策树算法也在不断地被改进和优化,为各种实际问题提供了强大的工具。

一、决策树思想

决策树的思想原理是通过对数据集中的特征进行递归划分,构建一系列的决策规则,从而生成一个树状结构。其基本思想可以总结如下:

  1. 选择最佳特征: 首先,从数据集中选择一个最佳的特征作为当前节点的划分标准。通常使用一些准则来评估特征的优劣,例如信息增益、基尼系数、均方误差等。
  2. 划分数据集:数据集根据选择的特征进行划分,生成多个子集,每个子集包含具有相同特征值的样本。
  3. 递归构建子树: 对每个子集递归地重复步骤1和步骤2,直到满足停止条件。停止条件可以是节点中样本的类别相同、达到最大深度、样本数量小于某个阈值等。
  4. 生成决策规则: 最终,每个叶子节点都表示一个类别或输出值,而每个内部节点都表示对输入特征的一个测试。通过将树的结构转化为一系列的if-then规则,可以解释数据的分类或预测过程。
  5. 剪枝优化(可选): 对生成的决策树进行剪枝优化,去除一些不必要的节点,防止过拟合。

通过这种方式,决策树可以根据输入特征对数据进行逐层的划分,构建出一个易于理解和解释的决策模型。决策树的基本思想是根据数据的特征值进行划分,通过划分后的数据集的纯度或者信息增益来选择最佳的划分特征,从而递归地构建出一个树状结构,实现对数据的分类或预测。

                                      开始
                                       |
                                       V
                             选择最佳特征作为根节点
                                       |
                                       V
              划分数据集,生成子集,选择最佳特征作为当前节点的划分标准
                     /              |              \
                    /               |               \
                   /                |                \
  子集1满足停止条件?           子集2满足停止条件?           子集3满足停止条件?
     /         |         \            /         |         \    
    /          |          \          /          |          \
   V           V           V        V           V           V
  生成叶子节点   递归构建子树    生成叶子节点   递归构建子树     生成叶子节点
     |               |               |               |               |
     V               V               V               V               V
    返回             返回             返回             返回             返回
     |               |               |               |               |
     V               V               V               V               V
                                      结束

二、经典决策树算法

经典的决策树算法包括ID3(Iterative Dichotomiser 3)、C4.5(Classification and Regression Trees)以及CART(Classification and Regression Trees)。这些算法在构建决策树时采用了不同的思想和策略,下面简要介绍它们的思想和实现步骤:

  1. ID3(Iterative Dichotomiser 3)
  • 思想: ID3算法是一种基于信息熵的决策树算法,它通过选择使得信息增益最大的特征来进行划分,以减少数据集的不确定性。
  • 实现步骤:
  1. 从所有特征中选择使得信息增益最大的特征作为当前节点的划分标准。
  2. 根据选定的特征进行划分,生成子集。
  3. 对每个子集递归地重复步骤1和步骤2,直到满足停止条件。
  4. 生成叶子节点,表示类别或输出值。
  5. 返回。
  1. C4.5(Classification and Regression Trees)
  • 思想: C4.5算法是ID3的改进版本,它在选择划分特征时采用信息增益比来解决ID3算法对取值数目较多特征的偏好。
  • 实现步骤:
  1. 从所有特征中选择使得信息增益比最大的特征作为当前节点的划分标准。
  2. 根据选定的特征进行划分,生成子集。
  3. 对每个子集递归地重复步骤1和步骤2,直到满足停止条件。
  4. 生成叶子节点,表示类别或输出值。
  5. 返回。
  1. CART(Classification and Regression Trees)
  • 思想: CART算法是一种同时适用于分类和回归问题的决策树算法,它通过选择使得基尼系数最小的特征来进行划分,以提高树的纯度。
  • 实现步骤:
  1. 从所有特征中选择使得基尼系数最小的特征作为当前节点的划分标准。
  2. 根据选定的特征进行划分,生成子集。
  3. 对每个子集递归地重复步骤1和步骤2,直到满足停止条件。
  4. 生成叶子节点,表示类别或输出值。
  5. 返回。

这些经典的决策树算法在实现时都采用了递归的思想,通过选择最佳的划分特征来构建树结构,直到满足停止条件为止。每个算法在选择划分特征时都采用了不同的指标,如信息增益、信息增益比、基尼系数等,以达到不同的优化目标。

三、算法应用案列

基于Python 和 Scikit-learn 库实现决策树算法的简单示例代码,用于解决分类问题:

首先我们将使用鸢尾花数据集,并尝试根据花萼和花瓣的长度和宽度来预测鸢尾花的品种。

第一步是加载了鸢尾花数据集,并选择花萼长度和花瓣长度作为特征。然后将数据分为训练集和测试集,并创建了一个决策树模型并在训练集上拟合了模型。最后,使用Matplotlib绘制了训练集和测试集的数据点,并在图上绘制了决策边界。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
 
# 加载鸢尾花数据集
iris = load_iris()
 
# 选择花萼长度和花瓣长度作为特征
X = iris.data[:, [0, 2]]
y = iris.target
 
# 将数据分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.29, random_state=42)
 
# 创建决策树模型
model = DecisionTreeClassifier()
 
# 在训练集上拟合模型
model.fit(X_train, y_train)
 
# 在测试集上进行预测
y_pred = model.predict(X_test)
 
# 计算模型的准确率
accuracy = accuracy_score(y_test, y_pred)
print("模型的准确率:", accuracy)
 
# 打印分类报告
print("分类报告:")
print(classification_report(y_test, y_pred))
 
# 绘制数据变化图
plt.figure(figsize=(10, 6))
 
# 绘制训练集数据点
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap='viridis', label='Training Set')
 
# 绘制测试集数据点
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap='viridis', marker='x', label='Test Set')
 
# 绘制决策边界
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01),
                     np.arange(y_min, y_max, 0.01))
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, alpha=0.3, cmap='viridis')
 
plt.xlabel('Sepal Length (cm)')
plt.ylabel('Petal Length (cm)')
plt.title('Decision Tree Classifier - Iris Dataset')
 
plt.legend()
plt.colorbar(label='Target Class')
 
plt.show()

执行结果:数据集划分29%测试集,71%训练集。精确率约为95%

四、总结

算法

  • 原理简单直观: 决策树基于对数据集中特征的递归划分,生成一系列的决策规则,形成树状结构,易于理解和解释。
  • 可解释性强: 决策树模型生成的规则可以直观地解释为基于哪些特征进行分类或预测,为决策过程提供了透明度。
  • 能够处理混合数据类型: 决策树算法能够处理包括连续型和离散型特征在内的多种数据类型。
  • 适用于多种任务: 决策树可用于分类和回归问题,并且能够进行特征选择和缺失值处理等任务。
  • 可扩展性好: 决策树可以与其他算法结合,如随机森林和梯度提升树等,以提高预测性能。

决策树算法应用:

  • 医疗诊断: 决策树可用于根据患者的症状和体征进行医学诊断,帮助医生做出治疗决策。
  • 金融风险评估: 决策树可用于根据借款人的信用记录和财务状况来评估贷款风险,并决定是否批准贷款。
  • 市场营销: 决策树可用于分析客户的行为和偏好,帮助企业制定个性化的营销策略。
  • 生态学研究: 决策树可用于分析生态系统中不同因素之间的关系,帮助科学家理解生态系统的结构和功能。

决策树算法优缺点:

优点:

  1. 简单直观,易于理解和解释。
  2. 可解释性强,生成的规则直观可见。
  3. 能够处理混合数据类型,包括连续型和离散型特征。
  4. 适用于多种任务,包括分类、回归、特征选择等。
  5. 可扩展性好,能够与其他算法结合提高预测性能。

缺点:

  1. 容易过拟合,特别是在处理复杂数据集时。
  2. 对于类别数量较多的特征,决策树倾向于选择类别数较多的特征进行划分。
  3. 不稳定性高,对输入数据的小变化可能会导致树结构的大变化。
  4. 在处理连续型数据时可能产生过于复杂的树结构,需要进行剪枝等操作来减少模型复杂度。


目录
打赏
0
0
0
0
17
分享
相关文章
|
3天前
|
时间序列异常检测:MSET-SPRT组合方法的原理和Python代码实现
MSET-SPRT是一种结合多元状态估计技术(MSET)与序贯概率比检验(SPRT)的混合框架,专为高维度、强关联数据流的异常检测设计。MSET通过历史数据建模估计系统预期状态,SPRT基于统计推断判定偏差显著性,二者协同实现精准高效的异常识别。本文以Python为例,展示其在模拟数据中的应用,证明其在工业监控、设备健康管理及网络安全等领域的可靠性与有效性。
413 7
时间序列异常检测:MSET-SPRT组合方法的原理和Python代码实现
【Azure Developer】分享两段Python代码处理表格(CSV格式)数据 : 根据每列的内容生成SQL语句
本文介绍了使用Python Pandas处理数据收集任务中格式不统一的问题。针对两种情况:服务名对应多人拥有状态(1/0表示),以及服务名与人名重复列的情况,分别采用双层for循环和字典数据结构实现数据转换,最终生成Name对应的Services列表(逗号分隔)。此方法高效解决大量数据的人工处理难题,减少错误并提升效率。文中附带代码示例及执行结果截图,便于理解和实践。
机器学习特征筛选:向后淘汰法原理与Python实现
向后淘汰法(Backward Elimination)是机器学习中一种重要的特征选择技术,通过系统性地移除对模型贡献较小的特征,以提高模型性能和可解释性。该方法从完整特征集出发,逐步剔除不重要的特征,最终保留最具影响力的变量子集。其优势包括提升模型简洁性和性能,减少过拟合,降低计算复杂度。然而,该方法在高维特征空间中计算成本较高,且可能陷入局部最优解。适用于线性回归、逻辑回归等统计学习模型。
81 7
如何在Python下实现摄像头|屏幕|AI视觉算法数据的RTMP直播推送
本文详细讲解了在Python环境下使用大牛直播SDK实现RTMP推流的过程。从技术背景到代码实现,涵盖Python生态优势、AI视觉算法应用、RTMP稳定性及跨平台支持等内容。通过丰富功能如音频编码、视频编码、实时预览等,结合实际代码示例,为开发者提供完整指南。同时探讨C接口转换Python时的注意事项,包括数据类型映射、内存管理、回调函数等关键点。最终总结Python在RTMP推流与AI视觉算法结合中的重要性与前景,为行业应用带来便利与革新。
基于 Python 哈希表算法的员工上网管理策略研究
于当下数字化办公环境而言,员工上网管理已成为企业运营管理的关键环节。企业有必要对员工的网络访问行为予以监控,以此确保信息安全并提升工作效率。在处理员工上网管理相关数据时,适宜的数据结构与算法起着举足轻重的作用。本文将深入探究哈希表这一数据结构在员工上网管理场景中的应用,并借助 Python 代码示例展开详尽阐述。
31 3
Python下的毫秒级延迟RTSP|RTMP播放器技术探究和AI视觉算法对接
本文深入解析了基于Python实现的RTSP/RTMP播放器,探讨其代码结构、实现原理及优化策略。播放器通过大牛直播SDK提供的接口,支持低延迟播放,适用于实时监控、视频会议和智能分析等场景。文章详细介绍了播放控制、硬件解码、录像与截图功能,并分析了回调机制和UI设计。此外,还讨论了性能优化方法(如硬件加速、异步处理)和功能扩展(如音量调节、多格式支持)。针对AI视觉算法对接,文章提供了YUV/RGB数据处理示例,便于开发者在Python环境下进行算法集成。最终,播放器凭借低延迟、高兼容性和灵活扩展性,为实时交互场景提供了高效解决方案。
Python 高级编程与实战:深入理解数据科学与机器学习
本文深入探讨了Python在数据科学与机器学习中的应用,介绍了pandas、numpy、matplotlib等数据科学工具,以及scikit-learn、tensorflow、keras等机器学习库。通过实战项目,如数据可视化和鸢尾花数据集分类,帮助读者掌握这些技术。最后提供了进一步学习资源,助力提升Python编程技能。
探秘文件共享服务之哈希表助力 Python 算法实现
在数字化时代,文件共享服务不可或缺。哈希表(散列表)通过键值对存储数据,利用哈希函数将键映射到特定位置,极大提升文件上传、下载和搜索效率。例如,在大型文件共享平台中,文件名等信息作为键,物理地址作为值存入哈希表,用户检索时快速定位文件,减少遍历时间。此外,哈希表还用于文件一致性校验,确保传输文件未被篡改。以Python代码示例展示基于哈希表的文件索引实现,模拟文件共享服务的文件索引构建与检索功能。哈希表及其分布式变体如一致性哈希算法,保障文件均匀分布和负载均衡,持续优化文件共享服务性能。
|
24天前
|
公司电脑网络监控场景下 Python 广度优先搜索算法的深度剖析
在数字化办公时代,公司电脑网络监控至关重要。广度优先搜索(BFS)算法在构建网络拓扑、检测安全威胁和优化资源分配方面发挥重要作用。通过Python代码示例展示其应用流程,助力企业提升网络安全与效率。未来,更多创新算法将融入该领域,保障企业数字化发展。
42 10
|
25天前
|
基于 Python 广度优先搜索算法的监控局域网电脑研究
随着局域网规模扩大,企业对高效监控计算机的需求增加。广度优先搜索(BFS)算法凭借其层次化遍历特性,在Python中可用于实现局域网内的计算机设备信息收集、网络连接状态监测及安全漏洞扫描,确保网络安全与稳定运行。通过合理选择数据结构与算法,BFS显著提升了监控效能,助力企业实现智能化的网络管理。
28 7

热门文章

最新文章