决策树的介绍
决策树是一种常见的分类模型,在金融风控、医疗辅助诊断等诸多行业具有较为广泛的应用。决策树的核心思想是基于树结构对数据进行划分,这种思想是人类处理问题时的本能方法。例如在婚恋市场中,女方通常会先询问男方是否有房产,如果有房产再了解是否有车产,如果有车产再看是否有稳定工作……最后得出是否要深入了解的判断。
决策树的主要优点:
- 具有很好的解释性,模型可以生成可以理解的规则。
- 可以发现特征的重要程度。
- 模型的计算复杂度较低。
决策树的主要缺点:
- 模型容易过拟合,需要采用减枝技术处理。
- 不能很好利用连续型特征。
- 预测能力有限,无法达到其他强监督模型效果。
- 方差较高,数据分布的轻微改变很容易造成树结构完全不同。
由于决策树模型中自变量与因变量的非线性关系以及决策树简单的计算方法,使得它成为集成学习中最为广泛使用的基模型。梯度提升树,XGBoost以及LightGBM等先进的集成模型都采用了决策树作为基模型,在广告计算、CTR预估、金融风控等领域大放异彩 ,同时决策树在一些明确需要可解释性或者提取分类规则的场景中被广泛应用,而其他机器学习模型在这一点很难做到。例如在医疗辅助系统中,为了方便专业人员发现错误,常常将决策树算法用于辅助病症检测。
决策树的应用
通过sklearn实现决策树分类
import numpy as np import matplotlib.pyplot as plt from sklearn import datasets iris = datasets.load_iris() X = iris.data[:,2:] y = iris.target plt.scatter(X[y==0,0],X[y==0,1]) plt.scatter(X[y==1,0],X[y==1,1]) plt.scatter(X[y==2,0],X[y==2,1]) plt.show()
from sklearn.tree import DecisionTreeClassifier tree = DecisionTreeClassifier(max_depth=2,criterion="entropy") tree.fit(X,y)
依据模型绘制决策树的决策边界
def plot_decision_boundary(model,axis): x0,x1 = np.meshgrid( np.linspace(axis[0],axis[1],int((axis[1]-axis[0])*100)).reshape(-1,1), np.linspace(axis[2],axis[3],int((axis[3]-axis[2])*100)).reshape(-1,1) ) X_new = np.c_[x0.ravel(),x1.ravel()] y_predict = model.predict(X_new) zz = y_predict.reshape(x0.shape) from matplotlib.colors import ListedColormap custom_map = ListedColormap(["#EF9A9A","#FFF59D","#90CAF9"]) plt.contourf(x0,x1,zz,linewidth=5,cmap=custom_map) plot_decision_boundary(tree,axis=[0.5,7.5,0,3]) plt.scatter(X[y==0,0],X[y==0,1]) plt.scatter(X[y==1,0],X[y==1,1]) plt.scatter(X[y==2,0],X[y==2,1]) plt.show()
实战:
Step: 库函数导入
import numpy as np import matplotlib.pyplot as plt import seaborn as sns from sklearn.tree import DecisionTreeClassifier from sklearn import tree
Step: 训练模型
x_fearures = np.array([[-1, -2], [-2, -1], [-3, -2], [1, 3], [2, 1], [3, 2]]) y_label = np.array([0, 1, 0, 1, 0, 1]) tree_clf = DecisionTreeClassifier() tree_clf = tree_clf.fit(x_fearures, y_label)
Step: 数据和模型可视化
plt.figure() plt.scatter(x_fearures[:,0],x_fearures[:,1], c=y_label, s=50, cmap='viridis') plt.title('Dataset') plt.show() import graphviz dot_data = tree.export_graphviz(tree_clf, out_file=None) graph = graphviz.Source(dot_data) graph.render("pengunis")
Step:模型预测
x_fearures_new1 = np.array([[0, -1]]) x_fearures_new2 = np.array([[2, 1]]) y_label_new1_predict = tree_clf.predict(x_fearures_new1) y_label_new2_predict = tree_clf.predict(x_fearures_new2) print('The New point 1 predict class:\n',y_label_new1_predict) print('The New point 2 predict class:\n',y_label_new2_predict)
机器学习算法决策树(二)+https://developer.aliyun.com/article/1544103?spm=a2c6h.13148508.setting.16.1fa24f0eRBJGs5