TensorFlow决策森林构建GBDT(Python)

简介: TensorFlow决策森林构建GBDT(Python)

一、Deep Learning is Not All You Need


尽管神经网络在图像识别、自然语言等很多领域大放异彩,但回到表格数据的数据挖掘任务中,树模型才是低调王者,如论文《Tabular Data: Deep Learning is Not All You Need》提及的:



深度学习可能不是解决所有机器学习问题的灵丹妙药,通过树模型在处理表格数据时性能与神经网络相当(甚至优于神经网络),而且树模型易于训练使用,有较好的可解释性。


二、树模型的使用


对于决策树等模型的使用,通常是要到scikit-learn、xgboost、lightgbm等机器学习库调用, 这和深度学习库是独立割裂的,不太方便树模型与神经网络的模型融合。



一个好消息是,Google 开源了 TensorFlow 决策森林(TF-DF),为基于树的模型和神经网络提供统一的接口,可以直接用TensorFlow调用树模型。决策森林(TF-DF)简单来说就是用TensorFlow封装了常用的随机森林(RF)、梯度提升(GBDT)等算法,其底层算法是基于C++的 Yggdrasil 决策森林 (YDF)实现的。


三、TensorFlow构建GBDT实践


TF-DF安装很简单pip install -U tensorflow_decision_forests,有个遗憾是目前只支持Linux环境,如果本地用不了将代码复制到 Google Colab 试试~


  • 本例的数据集用的癌细胞分类的数据集,首先加载下常用的模块及数据集:


importnumpyasnp importpandasaspd importmatplotlib.pyplotasplt importtensorflowastf tf.random.set_seed(123) fromsklearnimportdatasets fromsklearn.model_selectionimporttrain_test_split fromsklearn.metricsimportprecision_score,recall_score,f1_score,roc_curve dataset_cancer=datasets.load_breast_cancer()#加载癌细胞数据集 #print(dataset_cancer['DESCR']) df=pd.DataFrame(dataset_cancer.data,columns=dataset_cancer.feature_names) df['label']=dataset_cancer.target print(df.shape) df.head()



  • 划分数据集,并简单做下数据EDA分析:


# holdout验证法:按3:7划分测试集训练集 x_train,x_test=train_test_split(df,test_size=0.3) # EDA分析:数据统计指标 x_train.describe(include='all')



  • 构建TensorFlow的GBDT模型:TD-DF 一个非常方便的地方是它不需要对数据进行任何预处理。它会自动处理数字和分类特征,以及缺失值,我们只需要将df转换为 TensorFlow 数据集,如下一些超参数设定:



模型方面的树的一些常规超参数,类似于scikit-learn的GBDT



此外,还有带有正则化(dropout、earlystop)、损失函数(focal-loss)、效率方面(goss基于梯度采样)等优化方法:



构建模型、编译及训练,一步到位:


#模型参数 model_tf=tfdf.keras.GradientBoostedTreesModel(loss="BINARY_FOCAL_LOSS") #模型训练 model_tf.compile() model_tf.fit(x=train_ds,validation_freq=0.1)


  • 评估模型效果


##模型评估 可以看到test的准确率已经都接近1,可以再那个困难的数据任务试试~ evaluation=model_tf.evaluate(test_ds,return_dict=True) probs=model_tf.predict(test_ds) fpr,tpr,_=roc_curve(x_test.label,probs) plt.plot(fpr,tpr) plt.title('ROCcurve') plt.xlabel('falsepositiverate') plt.ylabel('truepositiverate') plt.xlim(0,) plt.ylim(0,) plt.show() print(evaluation)


  • 模型解释性 GBDT等树模型还有另外一个很大的优势是解释性,这里TF-DF也有实现。模型情况及特征重要性可以通过print(model_tf.summary())打印出来,



特征重要性支持了几种不同的方法评估:


MEAN_MIN_DEPTH指标。平均最小深度越小,较低的值意味着大量样本是基于此特征进行分类的,变量越重要。



NUM_NODES指标。它显示了给定特征被用作分割的次数,类似split。此外还有其他指标就不一一列举了。



我们还可以打印出模型的具体决策的树结构,通过运行tfdf.model_plotter.plot_model_in_colab(model_tf, tree_idx=0,

max_depth=10),整个过程还是比较清晰的。



小结


基于TensorFlow的TF-DF的树模型方法,我们可以方便训练树模型(特别对于熟练TensorFlow框架的同学),更进一步,也可以与TensorFlow的神经网络模型做效果对比、树模型与神经网络模型融合、利用异构模型先特征表示学习再输入模型(如GBDT+DNN、DNN embedding+GBDT),进一步了解可见如下参考文献。

相关文章
|
2月前
|
机器学习/深度学习 数据挖掘 Python
Python编程入门——从零开始构建你的第一个程序
【10月更文挑战第39天】本文将带你走进Python的世界,通过简单易懂的语言和实际的代码示例,让你快速掌握Python的基础语法。无论你是编程新手还是想学习新语言的老手,这篇文章都能为你提供有价值的信息。我们将从变量、数据类型、控制结构等基本概念入手,逐步过渡到函数、模块等高级特性,最后通过一个综合示例来巩固所学知识。让我们一起开启Python编程之旅吧!
|
1月前
|
数据采集 分布式计算 大数据
构建高效的数据管道:使用Python进行ETL任务
在数据驱动的世界中,高效地处理和移动数据是至关重要的。本文将引导你通过一个实际的Python ETL(提取、转换、加载)项目,从概念到实现。我们将探索如何设计一个灵活且可扩展的数据管道,确保数据的准确性和完整性。无论你是数据工程师、分析师还是任何对数据处理感兴趣的人,这篇文章都将成为你工具箱中的宝贵资源。
|
1月前
|
机器学习/深度学习 人工智能 算法
深度学习入门:用Python构建你的第一个神经网络
在人工智能的海洋中,深度学习是那艘能够带你远航的船。本文将作为你的航标,引导你搭建第一个神经网络模型,让你领略深度学习的魅力。通过简单直观的语言和实例,我们将一起探索隐藏在数据背后的模式,体验从零开始创造智能系统的快感。准备好了吗?让我们启航吧!
80 3
|
2月前
|
机器学习/深度学习 数据采集 数据可视化
TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤
本文介绍了 TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤,包括数据准备、模型定义、损失函数与优化器选择、模型训练与评估、模型保存与部署,并展示了构建全连接神经网络的具体示例。此外,还探讨了 TensorFlow 的高级特性,如自动微分、模型可视化和分布式训练,以及其在未来的发展前景。
145 5
|
2月前
|
数据采集 XML 存储
构建高效的Python网络爬虫:从入门到实践
本文旨在通过深入浅出的方式,引导读者从零开始构建一个高效的Python网络爬虫。我们将探索爬虫的基本原理、核心组件以及如何利用Python的强大库进行数据抓取和处理。文章不仅提供理论指导,还结合实战案例,让读者能够快速掌握爬虫技术,并应用于实际项目中。无论你是编程新手还是有一定基础的开发者,都能在这篇文章中找到有价值的内容。
|
2月前
|
JSON 前端开发 API
使用Python和Flask构建简易Web API
使用Python和Flask构建简易Web API
128 3
|
2月前
|
存储 API 数据库
使用Python和Flask构建简单的RESTful API
使用Python和Flask构建简单的RESTful API
|
2月前
|
JSON 关系型数据库 测试技术
使用Python和Flask构建RESTful API服务
使用Python和Flask构建RESTful API服务
|
2月前
|
机器学习/深度学习 数据采集 搜索推荐
利用Python和机器学习构建电影推荐系统
利用Python和机器学习构建电影推荐系统
133 1
|
2月前
|
数据采集 数据可视化 数据挖掘
掌握Python数据分析,解锁数据驱动的决策能力
掌握Python数据分析,解锁数据驱动的决策能力
下一篇
开通oss服务