Python可视化决策树【Matplotlib/Graphviz】

简介:

决策树是一种流行的有监督学习方法。决策树的优势在于其既可以用于回归,也可以用于分类,不需要特征缩放,而且具有比较好的可解释性,容易将决策树可视化。可视化的决策树不仅是理解你的模型的好办法,也是向其他人介绍你的模型的运作机制的有利工具。因此掌握决策树可视化的方法对于数据分析工作者来说非常重要。

机器学习相关教程:TensorFlow实战 | 机器学习基础 | 深入浅出Flask | Python基础

在这个教程里,我们将学习以下内容:

  • 如何使用scikit-learn训练一个决策树模型
  • 如何使用Matplotlib将决策树可视化
  • 如何使用Graphviz将决策树可视化
  • 如何将随机森林或决策树包中的单个决策树可视化

教程的代码可以从这里下载。现在让我们开始吧。

1、用scikit-learn训练决策树模型

为了可视化决策树,我们首先需要用scikit-learn训练出一个决策树模型。

首先导入必要的Python库:

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
from sklearn import tree

然后载入iris数据集。scikit-learn内置了Iris数据集,因此我们不需要从其他网站下载了。下面的Python代码载入Iris数据集:

import pandas as pd
from sklearn.datasets import load_irisdata = load_iris()
df = pd.DataFrame(data.data, columns=data.feature_names)
df['target'] = data.target

Iris数据集看起来是这样:

在这里插入图片描述

接下来我们将Iris数据集拆分为训练集和测试集:

X_train, X_test, Y_train, Y_test = train_test_split(df[data.feature_names], df['target'], random_state=0)

分割后的Iris数据集看起来是这样:

在这里插入图片描述

最后,我们采用scikit-learn经典的4步模式训练决策树模型:

# Step 1: Import the model you want to use
# This was already imported earlier in the notebook so commenting out
#from sklearn.tree import DecisionTreeClassifier
# Step 2: Make an instance of the Model
clf = DecisionTreeClassifier(max_depth = 2, 
                             random_state = 0)
# Step 3: Train the model on the data
clf.fit(X_train, Y_train)
# Step 4: Predict labels of unseen (test) data
# Not doing this step in the tutorial
# clf.predict(X_test)

2、使用Matplotlib将决策树可视化

从scikit-learn 版本21.0开始,可以使用scikit-learn的tree.plot_tree方法来利用matplotlib将决策树可视化,而不再需要依赖于难以安装的dot库。下面的Python代码展示了如何使用scikit-learn将决策树可视化:

tree.plot_tree(clf);

决策树可视化结果如下:
在这里插入图片描述

还可以添加一些额外的Python代码以便让绘制出的决策树具有更好的
可解读性,例如添加特征和分类名称:

fn=['sepal length (cm)','sepal width (cm)','petal length (cm)','petal width (cm)']
cn=['setosa', 'versicolor', 'virginica']
fig, axes = plt.subplots(nrows = 1,ncols = 1,figsize = (4,4), dpi=300)
tree.plot_tree(clf,
               feature_names = fn, 
               class_names=cn,
               filled = True);
fig.savefig('imagename.png')

增加了更多信息的决策树可视化结果如下:

在这里插入图片描述

3、使用Graphviz将决策树可视化

下图是使用Graphviz得到的决策树可视化结果:

在这里插入图片描述

Graphviz是一个开源的图(Graph)可视化软件,采用抽象的图和网络来表示结构化的信息。在数据科学领域,Graphviz的一个用途就是实现决策树可视化。我将graphviz方法放在matplotlib方法之后,是因为这个软件用起来有点复杂。

为了将决策树可视化,首先需要创建一个dot文件来描述决策树,这个倒不难。问题在于使用Graphviz将dot文件转换为图形文件,例如png、jpg等等可能会有点难度。

有一些办法来降低graphviz的使用门槛,例如通过Anaconda安装python-graphviz、利用mac的homebrew安装grahpviz、利用官方提供的windows安装文件、或者使用在线转换器将决策树的dot文件转换为图形:

在这里插入图片描述
首先我们将决策树模型导出为dot文件:

tree.export_graphviz(clf,
                     out_file="tree.dot",
                     feature_names = fn, 
                     class_names=cn,
                     filled = True)

然后我们用conda安装graphviz:

conda install python-graphviz

现在就可以将决策树模型导出的dot文件转换为图形文件了:

dot -Tpng tree.dot -o tree.png

4、将决策树包或随机森林里的单个决策树可视化

决策树的一个缺点是通常其预测精度不够好。这部分原因在于其变化幅度比较大,对训练数据的不同拆分方式可能会生成截然不同的决策树模型。

在这里插入图片描述

上图可以表示决策树包或者随机森林模型之类的组合学习方法,通过将多个机器学习算法组合起来以期获得更好的预测性能。在这一部分,我们学习如何将这些组合模型中的单个决策树可视化。

首先还是使用scikit-learn来训练得到一个随机森林模型:

# Load the Breast Cancer (Diagnostic) Dataset
data = load_breast_cancer()
df = pd.DataFrame(data.data, columns=data.feature_names)
df['target'] = data.target
# Arrange Data into Features Matrix and Target Vector
X = df.loc[:, df.columns != 'target']
y = df.loc[:, 'target'].values
# Split the data into training and testing sets
X_train, X_test, Y_train, Y_test = train_test_split(X, y, random_state=0)
# Random Forests in `scikit-learn` (with N = 100)
rf = RandomForestClassifier(n_estimators=100,
                            random_state=0)
rf.fit(X_train, Y_train)

现在我们可以将模型中的单个决策树可视化。首先还是使用matplotlib。下面的python代码将第1个决策树可视化:

fn=data.feature_names
cn=data.target_names
fig, axes = plt.subplots(nrows = 1,ncols = 1,figsize = (4,4), dpi=800)
tree.plot_tree(rf.estimators_[0],
               feature_names = fn, 
               class_names=cn,
               filled = True);
fig.savefig('rf_individualtree.png')

得到的这个决策树可视化结果如下:

在这里插入图片描述

你可以试着使用matplotlib的subplot来将你期望的多个决策树可视化。例如下面的Python代码将组合模型中的前5个决策树可视化:

# This may not the best way to view each estimator as it is smallfn=data.feature_names
cn=data.target_names
fig, axes = plt.subplots(nrows = 1,ncols = 5,figsize = (10,2), dpi=3000)for index in range(0, 5):
    tree.plot_tree(rf.estimators_[index],
                   feature_names = fn, 
                   class_names=cn,
                   filled = True,
                   ax = axes[index]);
    
    axes[index].set_title('Estimator: ' + str(index), fontsize = 11)fig.savefig('rf_5trees.png')

不过我个人不喜欢这么做,因为这看起来太费眼睛了:

在这里插入图片描述

5、教程小结

在这个教程里,我们学习了如何使用matplotlib和graphviz将scikit-learn训练得到的决策树可视化,也学习了如何将组合模型中的一个或多个决策树可视化,希望这有助于你的数据分析工作。


原文链接:决策树可视化 — 汇智网

目录
相关文章
|
25天前
|
数据可视化 Python
以下是一些常用的图表类型及其Python代码示例,使用Matplotlib和Seaborn库。
通过这些思维导图和分析说明表,您可以更直观地理解和选择适合的数据可视化图表类型,帮助更有效地展示和分析数据。
64 8
|
29天前
|
数据可视化 编译器 Python
Manim:数学可视化的强大工具 | python小知识
Manim(Manim Community Edition)是由3Blue1Brown的Grant Sanderson开发的数学动画引擎,专为数学和科学可视化设计。它结合了Python的灵活性与LaTeX的精确性,支持多领域的内容展示,能生成清晰、精确的数学动画,广泛应用于教育视频制作。安装简单,入门容易,适合教育工作者和编程爱好者使用。
210 7
|
2月前
|
存储 数据可视化 数据挖掘
使用Python进行数据分析和可视化
本文将引导你理解如何使用Python进行数据分析和可视化。我们将从基础的数据结构开始,逐步深入到数据处理和分析的方法,最后通过实际的代码示例来展示如何创建直观的数据可视化。无论你是初学者还是有经验的开发者,这篇文章都将为你提供有价值的见解和技巧。让我们一起探索数据的世界,发现隐藏在数字背后的故事!
|
2月前
|
机器学习/深度学习 数据可视化 数据挖掘
使用Python进行数据分析和可视化
【10月更文挑战第42天】本文将介绍如何使用Python进行数据分析和可视化。我们将从数据导入、清洗、探索性分析、建模预测,以及结果的可视化展示等方面展开讲解。通过这篇文章,你将了解到Python在数据处理和分析中的强大功能,以及如何利用这些工具来提升你的工作效率。
|
2月前
|
数据可视化 搜索推荐 Shell
Python与Plotly:B站每周必看榜单的可视化解决方案
Python与Plotly:B站每周必看榜单的可视化解决方案
|
2月前
|
移动开发 数据可视化 数据挖掘
利用Python实现数据可视化:以Matplotlib和Seaborn为例
【10月更文挑战第37天】本文旨在引导读者理解并掌握使用Python进行数据可视化的基本方法。通过深入浅出的介绍,我们将探索如何使用两个流行的库——Matplotlib和Seaborn,来创建引人入胜的图表。文章将通过具体示例展示如何从简单的图表开始,逐步过渡到更复杂的可视化技术,帮助初学者构建起强大的数据呈现能力。
|
2月前
|
数据可视化 JavaScript 前端开发
Python中交互式Matplotlib图表
【10月更文挑战第20天】Matplotlib 是 Python 中最常用的绘图库之一,但默认生成的图表是静态的。通过结合 mpld3 库,可以轻松创建交互式图表,提升数据可视化效果。本文介绍了如何使用 mpld3 在 Python 中创建交互式散点图、折线图和直方图,并提供了详细的代码示例和安装方法。通过添加插件,可以实现缩放、平移和鼠标悬停显示数据标签等交互功能。希望本文能帮助读者掌握这一强大工具。
84 5
|
2月前
|
数据采集 数据可视化 数据挖掘
使用Python进行数据分析和可视化
【10月更文挑战第33天】本文将介绍如何使用Python编程语言进行数据分析和可视化。我们将从数据清洗开始,然后进行数据探索性分析,最后使用matplotlib和seaborn库进行数据可视化。通过阅读本文,你将学会如何运用Python进行数据处理和可视化展示。