Python手写决策树并应对过度拟合问题

简介: Python手写决策树并应对过度拟合问题

640.png

介绍

决策树是一种用于监督学习的算法。它使用树结构,其中包含两种类型的节点:决策节点和叶节点。决策节点通过在要素上询问布尔值将数据分为两个分支。叶节点代表一个类。训练过程是关于在具有特定特征的特定特征中找到“最佳”分割。预测过程是通过沿着路径的每个决策节点回答问题来从根到达叶节点。

基尼不纯度和熵

术语“最佳”拆分是指拆分之后,两个分支比任何其他可能的拆分更“有序”。我们如何定义更多有序的?这取决于我们选择哪种指标。通常,度量有两种类型:基尼不纯度和熵。这些指标越小,数据集就越“有序”。

640.png

这两个指标之间的差异非常微妙。但 在大多数应用中,两个指标的行为类似。以下是用于计算每个指标的代码。

defgini_impurity(y):
#calculategini_impuritygivenlabels/classesofeachexamplem=y.shape[0]
cnts=dict(zip(*np.unique(y, return_counts=True)))
impurity=1-sum((cnt/m)**2forcntincnts.values())
returnimpuritydefentropy(y):
#calculateentropygivenlabels/classesofeachexamplem=y.shape[0]
cnts=dict(zip(*np.unique(y, return_counts=True)))
disorder=-sum((cnt/m)*log(cnt/m) forcntincnts.values())
returndisorder

构建树

训练过程实质上是在树上建树。关键步骤是确定“最佳”分配。过程如下:我们尝试按每个功能中的每个唯一值分割数据,然后选择混乱程度最小的最佳数据。现在我们可以将此过程转换为python代码。

defget_split(X, y):
#loopthroughfeaturesandvaluestofindbestcombinationwiththemostinformationgainbest_gain, best_index, best_value=0, None, Nonecur_gini=gini_impurity(y)
n_features=X.shape[1]  
forindexinrange(n_features):  
values=np.unique(X[:, index], return_counts=False)  
forvalueinvalues:  
left, right=test_split(index, value, X, y)
ifleft['y'].shape[0] ==0orright['y'].shape[0] ==0:
continuegain=info_gain(left['y'], right['y'], cur_gini)
ifgain>best_gain:
best_gain, best_index, best_value=gain, index, valuebest_split= {'gain': best_gain, 'index': best_index, 'value': best_value}
returnbest_splitdeftest_split(index, value, X, y):
#splitagroupofexamplesbasedongivenindex (feature) andvaluemask=X[:, index] <valueleft= {'X': X[mask, :], 'y': y[mask]}
right= {'X': X[~mask, :], 'y': y[~mask]}
returnleft, rightdefinfo_gain(l_y, r_y, cur_gini):
#calculatetheinformationgainforacertainsplitm, n=l_y.shape[0], r_y.shape[0]
p=m/ (m+n)
returncur_gini-p*gini_impurity(l_y) - (1-p) *gini_impurity(r_y)

在构建树之前,我们先定义决策节点和叶节点。决策节点指定将在其上拆分的特征和值。它还指向左,右子项。叶节点包括类似于Counter对象的字典,该字典显示每个类有多少训练示例。这对于计算训练的准确性很有用。另外,它导致到达该叶子的每个示例的结果预测。

classDecision_Node:
#defineadecisionnodedef__init__(self, index, value, left, right):
self.index, self.value=index, valueself.left, self.right=left, rightclassLeaf:
#definealeafnodedef__init__(self, y):
self.counts=dict(zip(*np.unique(y, return_counts=True)))
self.prediction=max(self.counts.keys(), key=lambdax: self.counts[x])

鉴于其结构,通过递归构造树是最方便的。递归的出口是叶节点。当我们无法通过拆分提高数据纯度时,就会发生这种情况。如果我们可以找到“最佳”拆分,则这将成为决策节点。然后,我们对其左,右子级递归执行相同的操作。

defdecision_tree(X, y):
#trainthedecisiontreemodelwithadatasetcorrect_prediction=0defbuild_tree(X, y):
#recursivelybuildthetreesplit=get_split(X, y)
ifsplit['gain'] ==0:
nonlocalcorrect_predictionleaf=Leaf(y)
correct_prediction+=leaf.counts[leaf.prediction]
returnleafleft, right=test_split(split['index'], split['value'], X, y)
left_node=build_tree(left['X'], left['y'])
right_node=build_tree(right['X'], right['y'])
returnDecision_Node(split['index'], split['value'], left_node, right_node)
root=build_tree(X, y)
returncorrect_prediction/y.shape[0], root

预测

现在我们可以遍历树直到叶节点来预测一个示例。

defpredict(x, node):
ifisinstance(node, Leaf):
returnnode.predictionifx[node.index] <node.value:
returnpredict(x, node.left)
else:
returnpredict(x, node.right)

事实证明,训练精度为100%,决策边界看起来很奇怪!显然,该模型过度拟合了训练数据。好吧,如果考虑到这一点,如果我们继续拆分直到数据集变得更纯净,决策树将过度适合数据。换句话说,如果我们不停止分裂,该模型将正确分类每个示例!训练准确性为100%(除非具有完全相同功能的不同类别的示例),这丝毫不令人惊讶。

640.png

如何应对过度拟合?

从上一节中,我们知道决策树过拟合的幕后原因。为了防止过度拟合,我们需要在某个时候停止拆分树。因此,我们需要引入两个用于训练的超参数。它们是:树的最大深度和叶子的最小尺寸。让我们重写树的构建部分。

defdecision_tree(X, y, max_dep=5, min_size=10):
#trainthedecisiontreemodelwithadatasetcorrect_prediction=0defbuild_tree(X, y, dep, max_dep=max_dep, min_size=min_size):
#recursivelybuildthetreesplit=get_split(X, y)
ifsplit['gain'] ==0ordep>=max_depory.shape[0] <=min_size:
nonlocalcorrect_predictionleaf=Leaf(y)
correct_prediction+=leaf.counts[leaf.prediction]
returnleafleft, right=test_split(split['index'], split['value'], X, y)
left_node=build_tree(left['X'], left['y'], dep+1)
right_node=build_tree(right['X'], right['y'], dep+1)
returnDecision_Node(split['index'], split['value'], left_node, right_node)
root=build_tree(X, y, 0)
returncorrect_prediction/y.shape[0], root

现在我们可以重新训练数据并绘制决策边界。

640.png

树的可视化

接下来,我们将通过打印出决策树的节点来可视化决策树。节点的压痕与其深度成正比。

defprint_tree(node, indent="|---"):
#printthetreeifisinstance(node, Leaf):
print(indent+'Class:', node.prediction)
returnprint(indent+'feature_'+str(node.index) +' <= '+str(round(node.value, 2)))
print_tree(node.left, '|   '+indent)
print(indent+'feature_'+str(node.index) +' > '+str(round(node.value, 2)))
print_tree(node.right, '|   '+indent)

结果如下:

|---feature_1<=1.87||---feature_1<=-0.74|||---feature_1<=-1.79||||---feature_1<=-2.1|||||---Class: 2||||---feature_1>-2.1|||||---Class: 2|||---feature_1>-1.79||||---feature_0<=1.62|||||---feature_0<=-1.31||||||---Class: 2|||||---feature_0>-1.31||||||---feature_1<=-1.49|||||||---Class: 1||||||---feature_1>-1.49|||||||---Class: 1||||---feature_0>1.62|||||---Class: 2||---feature_1>-0.74|||---feature_1<=0.76||||---feature_0<=0.89|||||---feature_0<=-0.86||||||---feature_0<=-2.24|||||||---Class: 2||||||---feature_0>-2.24|||||||---Class: 1|||||---feature_0>-0.86||||||---Class: 0||||---feature_0>0.89|||||---feature_0<=2.13||||||---Class: 1|||||---feature_0>2.13||||||---Class: 2|||---feature_1>0.76||||---feature_0<=-1.6|||||---Class: 2||||---feature_0>-1.6|||||---feature_0<=1.35||||||---feature_1<=1.66|||||||---Class: 1||||||---feature_1>1.66|||||||---Class: 1|||||---feature_0>1.35||||||---Class: 2|---feature_1>1.87||---Class: 2

总结

与其他回归模型不同,决策树不使用正则化来对抗过度拟合。相反,它使用树修剪。选择正确的超参数(树的深度和叶子的大小)还需要进行实验,例如 使用超参数矩阵进行交叉验证。

对于完整的工作流,包括数据生成和绘图决策边界,完整的代码在这里:https://github.com/JunWorks/ML-Algorithm-with-Python/blob/master/decision_tree/decision_tree.ipynb



目录
相关文章
|
1天前
|
机器学习/深度学习 数据采集 算法
Python用逻辑回归、决策树、SVM、XGBoost 算法机器学习预测用户信贷行为数据分析报告
Python用逻辑回归、决策树、SVM、XGBoost 算法机器学习预测用户信贷行为数据分析报告
|
1天前
|
JSON 数据可视化 Shell
数据结构可视化 Graphviz在Python中的使用 [树的可视化]
数据结构可视化 Graphviz在Python中的使用 [树的可视化]
11 0
|
1天前
|
计算机视觉 Python
使用Python进行多点拟合以确定标准球的球心坐标
使用Python进行多点拟合以确定标准球的球心坐标
15 1
|
1天前
|
SQL 分布式计算 数据可视化
数据分享|Python、Spark SQL、MapReduce决策树、回归对车祸发生率影响因素可视化分析
数据分享|Python、Spark SQL、MapReduce决策树、回归对车祸发生率影响因素可视化分析
|
1天前
|
机器学习/深度学习 算法 Python
【Python机器学习专栏】机器学习中的过拟合与欠拟合
【4月更文挑战第30天】机器学习中,模型性能受数据、算法及复杂度影响。过拟合(训练数据学得太好,泛化能力弱)和欠拟合(模型太简单,无法准确预测)是常见问题。理解两者概念、原因、影响及检测方法对构建有效模型至关重要。解决策略包括增加数据量、简化模型、添加特征或选择更复杂模型。使用交叉验证等工具可帮助检测和缓解过拟合、欠拟合。
|
1天前
|
机器学习/深度学习 算法 数据可视化
【Python机器学习专栏】决策树算法的实现与解释
【4月更文挑战第30天】本文探讨了决策树算法,一种流行的监督学习方法,用于分类和回归。文章阐述了决策树的基本原理,其中内部节点代表特征判断,分支表示判断结果,叶节点代表类别。信息增益等标准用于衡量特征重要性。通过Python的scikit-learn库展示了构建鸢尾花数据集分类器的示例,包括训练、预测、评估和可视化决策树。最后,讨论了模型解释和特征重要性评估在优化中的作用。
|
1天前
|
机器学习/深度学习 PyTorch 算法框架/工具
Python用GAN生成对抗性神经网络判别模型拟合多维数组、分类识别手写数字图像可视化
Python用GAN生成对抗性神经网络判别模型拟合多维数组、分类识别手写数字图像可视化
|
1天前
|
资源调度 数据可视化 数据挖掘
Python用PyMC贝叶斯GLM广义线性模型、NUTS采样器拟合、后验分布可视化
Python用PyMC贝叶斯GLM广义线性模型、NUTS采样器拟合、后验分布可视化
|
1天前
|
机器学习/深度学习 算法 Python
PYTHON银行机器学习:回归、随机森林、KNN近邻、决策树、高斯朴素贝叶斯、支持向量机SVM分析营销活动数据|数据分享(下)
PYTHON银行机器学习:回归、随机森林、KNN近邻、决策树、高斯朴素贝叶斯、支持向量机SVM分析营销活动数据|数据分享
|
1天前
|
机器学习/深度学习 算法 数据挖掘
PYTHON银行机器学习:回归、随机森林、KNN近邻、决策树、高斯朴素贝叶斯、支持向量机SVM分析营销活动数据|数据分享(上)
PYTHON银行机器学习:回归、随机森林、KNN近邻、决策树、高斯朴素贝叶斯、支持向量机SVM分析营销活动数据|数据分享