决策树
决策树是属于有监督机器学习的一种,起源非常早,符合直觉并且非常直观,模仿人类做决 策的过程,早期人工智能模型中有很多应用,现在更多的是使用基于决策树的一些集成学习 的算法。这章我们把决策树算法理解透彻非常有利于后面去学习集成学习。
特点
1. 可以处理非线性的问题
2. 可解释性强 没有θ
3. 模型简单,模型预测效率高 if else
4. 不容易显示的使用函数表达,不可微
决策树模型生成和预测
模型生成:通过大量数据生成一颗非常好的树,用这棵树来预测新来的数据
预测:来一条新数据,按照生成好的树的标准,落到某一个节点上
生成决策树所需分裂指标
常用分裂条件
Gini 系数(CART)
基尼系数是指国际上通用的、用以衡量一个国家或地区居民收入差距的常用指标。 基尼系数最大为“1”,最小等于“0”。基尼系数越接近 0 表明收入分配越是趋向平 等。国际惯例把 0.2 以下视为收入绝对平均,0.2-0.3 视为收入比较平均;0.3-0.4 视为收入 相对合理;0.4-0.5 视为收入差距较大,当基尼系数达到 0.5 以上时,则表示收入悬殊。 基尼指数最早由意大利统计与社会学家 Corrado Gini 在 1912 年提出。 其具体含义是指,在全部居民收入中,用于进行不平均分配的那部分收入所占的比例。 基尼系数最大为“1”,最小等于“0”。前者表示居民之间的收入分配绝对不平均,即 100% 的收入被一个单位的人全部占有了;而后者则表示居民之间的收入分配绝对平均,即人与人之间收入完全平等,没有任何差异。但这两种情况只是在理论上的绝对化形式,在实际生活 中一般不会出现。因此,基尼系数的实际数值只能介于 0~1 之间,基尼系数越小收入分配 越平均,基尼系数越大收入分配越不平均。国际上通常把 0.4 作为贫富差距的警戒线,大于 这一数值容易出现社会动荡。
信息增益(ID3)
在信息论里熵叫作信息量,即熵是对不确定性的度量。从控制论的角度来看,应叫不确定性。信息论的创始人香农在其著作《通信的数学理论》中提出了建立在概率统计模型上的 信息度量。他把信息定义为“用来消除不确定性的东西”。在信息世界,熵越高,则能传输 越多的信息,熵越低,则意味着传输的信息越少。还是举例说明,假设 Kathy 在买衣服的 时候有颜色,尺寸,款式以及设计年份四种要求,而 North 只有颜色和尺寸的要求,那么 在购买衣服这个层面上 Kathy 由于选择更多因而不确定性因素更大,最终 Kathy 所获取的 信息更多,也就是熵更大。所以信息量=熵=不确定性,通俗易懂。在叙述决策树时我们用 熵表示不纯度(Impurity)。
信息增益:分裂前的信息熵 减去 分裂后的信息熵 一个分裂导致的信息增益越大,代表这次分裂提升的纯度越高
信息增益率(C4.5)
对于多叉树,如果不限制分裂多少支,一次分裂就可以将信息熵降为 0,比如 ID3
如何平衡分裂情况与信息增益? 信息增益率:信息增益 除以 类别 本身的
MSE
用于回归树
经典决策树算法
决策树优缺点
优点
1. 决策过程接近人的思维习惯。
2. 模型容易解释,比线性模型具有更好的解释性。
3. 能清楚地使用图形化描述模型。
4. 处理定型特征比较容易。
缺点
1. 一般来说,决策树学习方法的准确率不如其他的模型。针对这种情况存在一些解决方案, 在后面的文章中为大家讲解。
2. 不支持在线学习。当有新样本来的时候,需要重建决策树。
3. 容易产生过拟合现象。
经典决策树算法
ID3 和 C4.5 比较
ID3(Iterative Dichotomiser 3,迭代二叉树 3 代)由 Ross Quinlan 于 1986 年提出。 1993 年,他对 ID3 进行改进设计出了 C4.5 算法。 我们已经知道 ID3 与 C4.5 的不同之处在于,ID3 根据信息增益选取特征构造决策树, 而 C4.5 则是以信息增益率为核心构造决策树。既然 C4.5 是在 ID3 的基础上改进得到的, 那么这两者的优缺点分别是什么? 使用信息增益会让 ID3 算法更偏向于选择值多的属性。信息增益反映给定一个条件后 不确定性减少的程度,必然是分得越细的数据集确定性更高,也就是信息熵越小,信息增益 越大。因此,在一定条件下,值多的属性具有更大的信息增益。而 C4.5 则使用信息增益率 选择属性。信息增益率通过引入一个被称作分裂信息(Split information)的项来惩罚取值较 多的属性,分裂信息用来衡量属性分裂数据的广度和均匀性。这样就改进了 ID3 偏向选择 值多属性的缺点。
相对于 ID3 只能处理离散数据,C4.5 还能对连续属性进行处理,具体步骤为:
1. 把需要处理的样本(对应根节点)或样本子集(对应子树)按照连续变量的大小从小到大进 行排序。 2. 假设该属性对应的不同的属性值一共有 N 个,那么总共有 N−1 个可能的候选分割阈值 点,每个候选的分割阈值点的值为上述排序后的属性值中两两前后连续元素的中点,根 据这个分割点把原来连续的属性分成 bool 属性。实际上可以不用检查所有 N−1 个分 割点。(连续属性值比较多的时候,由于需要排序和扫描,会使 C4.5 的性能有所下降。)
3. 用信息增益比率选择最佳划分。
C4.5 其他优点
1)在树的构造过程中可以进行剪枝,缓解过拟合;
2)能够对连续属性进行离散化处理(二 分法);
3)能够对缺失值进行处理; 缺点:构造树的过程需要对数据集进行多次顺序扫描和排序,导致算法低效; 刚才我们提到 信息增益对可取值数目较多的属性有所偏好;而信息增益率对可取值数目较 少的属性有所偏好!OK,两者结合一下就好了!
解决方法:先从候选属性中找出信息增益高于平均水平的属性,再从中选择增益率最高的。 而不是大家常说的直接选择信息增益率最高的属性!
代码实战对鸢尾花数据集分类
from six import StringIO import pandas as pd import numpy as np from sklearn import tree from sklearn.datasets import load_iris from sklearn.tree import DecisionTreeClassifier from sklearn.tree import export_graphviz from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score import matplotlib.pyplot as plt import matplotlib as mpl import pydotplus iris = load_iris() data = pd.DataFrame(iris.data) data.columns = iris.feature_names data['Species'] = load_iris().target # 准备数据 x = data.iloc[:, 0:4] y = data.iloc[:, -1] x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.25, random_state=42) # 训练决策树模型 tree_clf = DecisionTreeClassifier(max_depth=8, criterion='gini') tree_clf.fit(x_train, y_train) # 用测试集进行预测,得出精确率 y_test_hat = tree_clf.predict(x_test) print("acc score:", accuracy_score(y_test, y_test_hat)) print(tree_clf.feature_importances_) # 将决策树保存成图片 dot_data = StringIO() tree.export_graphviz( tree_clf, out_file=dot_data, feature_names=iris.feature_names[:], class_names=iris.target_names, rounded=True, filled=True ) graph = pydotplus.graph_from_dot_data(dot_data.getvalue()) graph.write_png('tree.png')
acc score: 1.0 [0.03575134 0. 0.88187037 0.08237829]
我们看出用测试集预测的精确度为100%,特征中,第三个特征的重要程度系数最大,说明它对决策分类起很重要的作用
我们看出决策树的最大深度为6,而我们在前面设置的是8,那么这个超参数应该设置为多少呢,我们通过画图来分析
depth = np.arange(1, 15) err_list = [] for d in depth: print(d) clf = DecisionTreeClassifier(criterion='gini', max_depth=d) clf.fit(x_train, y_train) y_test_hat = clf.predict(x_test) result = (y_test_hat == y_test) if d == 1: print(result) err = 1 - np.mean(result) print(100 * err) err_list.append(err) print(d, ' 错误率:%.2f%%' % (100 * err)) mpl.rcParams['font.sans-serif'] = ['SimHei'] plt.figure(facecolor='w') plt.plot(depth, err_list, 'ro-', lw=2) plt.xlabel('决策树深度', fontsize=15) plt.ylabel('错误率', fontsize=15) plt.title('决策树深度和过拟合', fontsize=18) plt.grid(True) plt.show()
从图中我们看出当决策树深度等于3的时候,错误率就以及接近0了