快速入门Python机器学习(18)

简介: 快速入门Python机器学习(18)

9 决策树(Decision Tree)


9. 1 决策树原理

image.png


9.2 信息增益与基尼不纯度

信息熵(约翰·香农 1948《通信的数学原理》,一个问题不确定性越大,需要获取的信息就越多,信息熵就越大;一个问题不确定性越小,需要获取的信息就越少,信息熵就越小)


集合D中第k类样本的比率为pk,(k=1,2,…|y|)

image.png


信息增益(Information Gain):划分数据前后数据信息熵的差值。信息增益纯度越高,纯度提升越大;信息增益纯度越低,纯度提升越小。


基尼不纯度

image.png


基尼不纯度反映从集合D中随机取两个样本后,其类别不一致性的概率。


方法

算法

信息增益

ID3(改进C4.5)

基尼不纯度

CART


9.3 决策树分类(Decision Tree Classifier

9.3.1类、属性和方法


class sklearn.tree.DecisionTreeClassifier(*, criterion='gini', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, class_weight=None, ccp_alpha=0.0)


参数

属性

类型

解释

max_depth

int, default=None

树的最大深度。如果没有,则节点将展开,直到所有叶都是纯的,或者直到所有叶都包含少于min_samples_split samples的值。

criterion

{'gini''entropy'}, default='gini'

测量分割质量的函数。支持的标准是基尼杂质的'基尼'和信息增益的''


属性

属性

解释

classes_

ndarray of shape (n_classes,) or list of ndarray类标签(单输出问题)或类标签数组列表(多输出问题)。

feature_importances_

ndarray of shape (n_features,)返回功能重要性。

max_features_

intmax_features的推断值。

n_classes_

int or list of int类的数量(对于单个输出问题),或包含每个输出的类的数量的列表(对于多输出问题)。

n_features_

int执行拟合时的特征数。

n_outputs_

int执行拟合时的输出数。

tree_

Tree instance树实例基础树对象。请参阅帮助(sklearn.tree._tree.Tree)对于树对象的属性,了解决策树结构对于这些属性的基本用法。


方法

apply(X[, check_input])

返回每个样本预测为的叶的索引。

cost_complexity_pruning_path(X, y[, …])

在最小代价复杂度修剪过程中计算修剪路径。

decision_path(X[, check_input])

返回树中的决策路径。

fit(X, y[, sample_weight, check_input, …])

从训练集(Xy)构建决策树分类器。

get_depth()

返回决策树的深度。

get_n_leaves()

返回决策树的叶数。

get_params([deep])

获取此估计器的参数。

predict(X[, check_input])

预测X的类或回归值。

predict_log_proba(X)

预测输入样本X的类对数概率。

predict_proba(X[, check_input])

预测输入样本X的类概率。

score(X, y[, sample_weight])

返回给定测试数据和标签的平均精度。

set_params(**params)

设置此估计器的参数。


9.3.2用散点图来分析鸢尾花数据

def iris_of_decision_tree():
       myutil = util()
       iris = datasets.load_iris()
       # 仅选前两个特征
       X = iris.data[:,:2]
       y = iris.target
       X_train,X_test,y_train,y_test = train_test_split(X, y)
       for max_depth in [1,3,5,7]:
              clf = DecisionTreeClassifier(max_depth=max_depth)
              clf.fit(X_train,y_train)
              title=u"鸢尾花数据测试集(max_depth="+str(max_depth)+")"
              myutil.print_scores(clf,X_train,y_train,X_test,y_test,title)
              myutil.draw_scatter(X,y,clf,title)
              myutil.plot_learning_curve(DecisionTreeClassifier(max_depth=max_depth),X,y,title)
              myutil.show_pic(title)


输出

鸢尾花数据测试集(max_depth=1):
64.29%
鸢尾花数据测试集(max_depth=1):
57.89%
鸢尾花数据测试集(max_depth=3):
83.93%
鸢尾花数据测试集(max_depth=3):
71.05%
鸢尾花数据测试集(max_depth=5):
85.71%
鸢尾花数据测试集(max_depth=5):
73.68%
鸢尾花数据测试集(max_depth=7):
88.39%
鸢尾花数据测试集(max_depth=7):
65.79%


max_depth=5的时候效果最好

image.png

image.png

image.png

image.png

image.png

image.png

image.pngimage.png


9.3.3用散点图分析红酒数据

def wine_of_decision_tree():
       myutil = util()
       wine = datasets.load_wine()
       # 仅选前两个特征
       X = wine.data[:,:2]
       y = wine.target
       X_train,X_test,y_train,y_test = train_test_split(X, y)
       for max_depth in [1,3,5]:
              clf = DecisionTreeClassifier(max_depth=max_depth)
              clf.fit(X_train,y_train)
              title=u"红酒数据测试集(max_depth="+str(max_depth)+")"
              myutil.print_scores(clf,X_train,y_train,X_test,y_test,title)
              myutil.draw_scatter(X,y,clf,title)


输出

红酒数据测试集(max_depth=1):
69.17%
红酒数据测试集(max_depth=1):
64.44%
红酒数据测试集(max_depth=3):
87.97%
红酒数据测试集(max_depth=3):
80.00%
红酒数据测试集(max_depth=5):
90.98%
红酒数据测试集(max_depth=5):
80.00%
红酒数据测试集(max_depth=7):
96.99%
红酒数据测试集(max_depth=7):
73.33%


max_depth=5的时候效果最好;max_depth=7

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png


9.3.4用散点图分析乳腺癌数据

def wine_of_decision_tree():
       myutil = util()
       wine = datasets.load_wine()
       # 仅选前两个特征
       X = wine.data[:,:2]
       y = wine.target
       X_train,X_test,y_train,y_test = train_test_split(X, y)
       for max_depth in [1,3,5]:
              clf = DecisionTreeClassifier(max_depth=max_depth)
              clf.fit(X_train,y_train)
              title=u"乳腺癌数据测试集(max_depth="+str(max_depth)+")"
              myutil.print_scores(clf,X_train,y_train,X_test,y_test,title)
              myutil.draw_scatter(X,y,clf,title)


输出

乳腺癌数据测试集(max_depth=1):
90.38%
乳腺癌数据测试集(max_depth=1):
85.31%
乳腺癌数据测试集(max_depth=3):
91.08%
乳腺癌数据测试集(max_depth=3):
86.01%
乳腺癌数据测试集(max_depth=5):
93.43%
乳腺癌数据测试集(max_depth=5):
86.71%
乳腺癌数据测试集(max_depth=7):
96.95%
乳腺癌数据测试集(max_depth=7):
86.01%


max_depth=5的时候效果最好

image.png

image.png

image.png

image.png

image.png

image.png

image.pngimage.png


目录
相关文章
|
3月前
|
机器学习/深度学习 数据采集 数据可视化
Python数据科学实战:从Pandas到机器学习
Python数据科学实战:从Pandas到机器学习
|
3月前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
179 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
3月前
|
机器学习/深度学习 数据采集 人工智能
探索机器学习:从理论到Python代码实践
【10月更文挑战第36天】本文将深入浅出地介绍机器学习的基本概念、主要算法及其在Python中的实现。我们将通过实际案例,展示如何使用scikit-learn库进行数据预处理、模型选择和参数调优。无论你是初学者还是有一定基础的开发者,都能从中获得启发和实践指导。
87 2
|
3月前
|
机器学习/深度学习 数据可视化 数据处理
掌握Python数据科学基础——从数据处理到机器学习
掌握Python数据科学基础——从数据处理到机器学习
69 0
|
3月前
|
机器学习/深度学习 数据采集 人工智能
机器学习入门:Python与scikit-learn实战
机器学习入门:Python与scikit-learn实战
102 0
|
3月前
|
机器学习/深度学习 数据采集 数据挖掘
Python在数据科学中的应用:从数据处理到模型训练
Python在数据科学中的应用:从数据处理到模型训练
|
3月前
|
机器学习/深度学习 算法 数据挖掘
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构。本文介绍了K-means算法的基本原理,包括初始化、数据点分配与簇中心更新等步骤,以及如何在Python中实现该算法,最后讨论了其优缺点及应用场景。
192 6
|
1月前
|
机器学习/深度学习 人工智能 算法
机器学习算法的优化与改进:提升模型性能的策略与方法
机器学习算法的优化与改进:提升模型性能的策略与方法
260 13
机器学习算法的优化与改进:提升模型性能的策略与方法
|
1月前
|
机器学习/深度学习 算法 网络安全
CCS 2024:如何严格衡量机器学习算法的隐私泄露? ETH有了新发现
在2024年CCS会议上,苏黎世联邦理工学院的研究人员提出,当前对机器学习隐私保护措施的评估可能存在严重误导。研究通过LiRA攻击评估了五种经验性隐私保护措施(HAMP、RelaxLoss、SELENA、DFKD和SSL),发现现有方法忽视最脆弱数据点、使用较弱攻击且未与实际差分隐私基线比较。结果表明这些措施在更强攻击下表现不佳,而强大的差分隐私基线则提供了更好的隐私-效用权衡。
52 14
|
2月前
|
算法
PAI下面的gbdt、xgboost、ps-smart 算法如何优化?
设置gbdt 、xgboost等算法的样本和特征的采样率
90 2

热门文章

最新文章