使用Python中从头开始构建决策树算法

简介: 决策树(Decision Tree)是一种常见的机器学习算法,被广泛应用于分类和回归任务中。并且再其之上的随机森林和提升树等算法一直是表格领域的最佳模型,所以本文将介绍理解其数学概念,并在Python中动手实现,这可以作为了解这类算法的基础知识。

决策树(Decision Tree)是一种常见的机器学习算法,被广泛应用于分类和回归任务中。并且再其之上的随机森林和提升树等算法一直是表格领域的最佳模型,所以本文将介绍理解其数学概念,并在Python中动手实现,这可以作为了解这类算法的基础知识。

在深入研究代码之前,我们先要了解支撑决策树的数学概念:熵和信息增益

熵:杂质的量度

熵作为度量来量化数据集中的杂质或无序。特别是对于决策树,熵有助于衡量与一组标签相关的不确定性。数学上,数据集S的熵用以下公式计算:

 Entropy(S) = -p_pos * log2(p_pos) - p_neg * log2(p_neg)

P_pos表示数据集中正标签的比例,P_neg表示数据集中负标签的比例。

更高的熵意味着更大的不确定性或杂质,而更低的熵意味着更均匀的数据集。

信息增益:通过拆分提升知识

信息增益是评估通过基于特定属性划分数据集所获得的熵的减少。也就是说它衡量的是执行分割后标签确定性的增加。

数学上,对数据集S中属性a进行分割的信息增益计算如下:

 Information Gain(S, A) = Entropy(S) - ∑ (|S_v| / |S|) * Entropy(S_v)

S 表示原始数据集,A表示要拆分的属性。S_v表示属性A保存值v的S的子集。

目标是通过选择使信息增益最大化的属性,在决策树中创建信息量最大的分割。

在Python中实现决策树算法

有了以上的基础,就可以使用Python从头开始编写Decision Tree算法。

首先导入基本的numpy库,它将有助于我们的算法实现。

 import numpy as np

创建DecisionTree类

 class DecisionTree:
     def __init__(self, max_depth=None):
         self.max_depth = max_depth

定义了DecisionTree类来封装决策树。max_depth参数是树的最大深度,以防止过拟合。

 def fit(self, X, y, depth=0):
         n_samples, n_features = X.shape
         unique_classes = np.unique(y)

         # Base cases
         if (self.max_depth is not None and depth >= self.max_depth) or len(unique_classes) == 1:
             self.label = unique_classes[np.argmax(np.bincount(y))]
             return

拟合方法是决策树算法的核心。它需要训练数据X和相应的标签,以及一个可选的深度参数来跟踪树的深度。我们以最简单的方式处理树的生长:达到最大深度或者遇到纯类。

确定最佳分割属性,循环遍历所有属性以找到信息增益最大化的属性。_information_gain方法(稍后解释)帮助计算每个属性的信息增益。

 best_attribute = None
 best_info_gain = -1
 for feature in range(n_features):
             info_gain = self._information_gain(X, y, feature)
             if info_gain > best_info_gain:
                 best_info_gain = info_gain
                 best_attribute = feature

处理不分割属性,如果没有属性产生正的信息增益,则将类标签分配为节点的标签。

 if best_attribute is None:
             self.label = unique_classes[np.argmax(np.bincount(y))]
             return

分割和递归调用,下面代码确定了分割的最佳属性,并创建两个子节点。根据属性的阈值将数据集划分为左右两个子集。

 self.attribute = best_attribute
 self.threshold = np.median(X[:, best_attribute])

 left_indices = X[:, best_attribute] <= self.threshold
     right_indices = ~left_indices

     self.left = DecisionTree(max_depth=self.max_depth)
     self.right = DecisionTree(max_depth=self.max_depth)

     self.left.fit(X[left_indices], y[left_indices], depth + 1)
     self.right.fit(X[right_indices], y[right_indices], depth + 1)

并且通过递归调用左子集和右子集的fit方法来构建子树。

预测方法使用训练好的决策树进行预测。如果到达一个叶节点(带有标签的节点),它将叶节点的标签分配给X中的所有数据点。

 def predict(self, X):
         if hasattr(self, 'label'):
             return np.array([self.label] * X.shape[0])

当遇到非叶节点时,predict方法根据属性阈值递归遍历树的左子树和右子树。来自双方的预测被连接起来形成最终的预测数组。

 is_left = X[:, self.attribute] <= self.threshold
         left_predictions = self.left.predict(X[is_left])
         right_predictions = self.right.predict(X[~is_left])

         return np.concatenate((left_predictions, right_predictions))

下面两个方法是决策树的核心代码,并且可以使用不同的算法来进行计算,比如ID3 算法使用信息增益作为特征选择的标准,该标准度量了将某特征用于划分数据后,对分类结果的不确定性减少的程度。算法通过递归地选择信息增益最大的特征来构建决策树,也就是我们现在要演示的算法。

_information_gain方法计算给定属性的信息增益。它计算分裂后子熵的加权平均值,并从父熵中减去它。

 def _information_gain(self, X, y, feature):
         parent_entropy = self._entropy(y)

         unique_values = np.unique(X[:, feature])
         weighted_child_entropy = 0

         for value in unique_values:
             is_value = X[:, feature] == value
             child_entropy = self._entropy(y[is_value])
             weighted_child_entropy += (np.sum(is_value) / len(y)) * child_entropy

         return parent_entropy - weighted_child_entropy

熵的计算

 def _entropy(self, y):
         _, counts = np.unique(y, return_counts=True)
         probabilities = counts / len(y)
         return -np.sum(probabilities * np.log2(probabilities))

_entropy方法计算数据集y的熵,它计算每个类的概率,然后使用前面提到的公式计算熵。

常见的算法还有:

C4.5 是 ID3 的改进版本,C4.5 算法在特征选择时使用信息增益比,这是对信息增益的一种归一化,用于解决信息增益在选择特征时偏向于取值较多的特征的问题。

CART 与 ID3 和 C4.5 算法不同,CART(Classification And Regression Tree)又被称为分类回归树,算法采用基尼不纯度(Gini impurity)来度量节点的不确定性,该不纯度度量了从节点中随机选取两个样本,它们属于不同类别的概率。

ID3、C4.5 和 CART 算法都是基于决策树的经典算法,像Xgboost就是使用的CART 作为基础模型。

总结

以上就是使用Python中构造了一个完整的决策树算法的全部。决策树的核心思想是根据数据的特征逐步进行划分,使得每个子集内的数据尽量属于同一类别或具有相似的数值。在构建决策树时,通常会使用一些算法来选择最佳的特征和分割点,以达到更好的分类或预测效果。

https://avoid.overfit.cn/post/212f6b68e0d441c2b3db40304e637e32

作者:Matteo Possamai

目录
相关文章
|
13天前
|
机器学习/深度学习 数据挖掘 Python
Python编程入门——从零开始构建你的第一个程序
【10月更文挑战第39天】本文将带你走进Python的世界,通过简单易懂的语言和实际的代码示例,让你快速掌握Python的基础语法。无论你是编程新手还是想学习新语言的老手,这篇文章都能为你提供有价值的信息。我们将从变量、数据类型、控制结构等基本概念入手,逐步过渡到函数、模块等高级特性,最后通过一个综合示例来巩固所学知识。让我们一起开启Python编程之旅吧!
|
3天前
|
算法 数据安全/隐私保护 开发者
马特赛特旋转算法:Python的随机模块背后的力量
马特赛特旋转算法是Python `random`模块的核心,由松本真和西村拓士于1997年提出。它基于线性反馈移位寄存器,具有超长周期和高维均匀性,适用于模拟、密码学等领域。Python中通过设置种子值初始化状态数组,经状态更新和输出提取生成随机数,代码简单高效。
|
21天前
|
弹性计算 数据管理 数据库
从零开始构建员工管理系统:Python与SQLite3的完美结合
本文介绍如何使用Python和Tkinter构建一个图形界面的员工管理系统(EMS)。系统包括数据库设计、核心功能实现和图形用户界面创建。主要功能有查询、添加、删除员工信息及统计员工数量。通过本文,你将学会如何结合SQLite数据库进行数据管理,并使用Tkinter创建友好的用户界面。
从零开始构建员工管理系统:Python与SQLite3的完美结合
|
8天前
|
存储 API 数据库
使用Python和Flask构建简单的RESTful API
使用Python和Flask构建简单的RESTful API
|
13天前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【垃圾识别系统】实现~TensorFlow+人工智能+算法网络
垃圾识别分类系统。本系统采用Python作为主要编程语言,通过收集了5种常见的垃圾数据集('塑料', '玻璃', '纸张', '纸板', '金属'),然后基于TensorFlow搭建卷积神经网络算法模型,通过对图像数据集进行多轮迭代训练,最后得到一个识别精度较高的模型文件。然后使用Django搭建Web网页端可视化操作界面,实现用户在网页端上传一张垃圾图片识别其名称。
46 0
基于Python深度学习的【垃圾识别系统】实现~TensorFlow+人工智能+算法网络
|
13天前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
48 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
13天前
|
机器学习/深度学习 人工智能 算法
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
蔬菜识别系统,本系统使用Python作为主要编程语言,通过收集了8种常见的蔬菜图像数据集('土豆', '大白菜', '大葱', '莲藕', '菠菜', '西红柿', '韭菜', '黄瓜'),然后基于TensorFlow搭建卷积神经网络算法模型,通过多轮迭代训练最后得到一个识别精度较高的模型文件。在使用Django开发web网页端操作界面,实现用户上传一张蔬菜图片识别其名称。
55 0
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
|
18天前
|
机器学习/深度学习 TensorFlow 算法框架/工具
利用Python和TensorFlow构建简单神经网络进行图像分类
利用Python和TensorFlow构建简单神经网络进行图像分类
39 3
|
18天前
|
开发框架 前端开发 JavaScript
利用Python和Flask构建轻量级Web应用的实战指南
利用Python和Flask构建轻量级Web应用的实战指南
52 2
|
18天前
|
机器学习/深度学习 JSON API
Python编程实战:构建一个简单的天气预报应用
Python编程实战:构建一个简单的天气预报应用
33 1

热门文章

最新文章

下一篇
无影云桌面