TensorFlow决策森林构建GBDT(Python)

简介: TensorFlow决策森林构建GBDT(Python)

一、Deep Learning is Not All You Need


尽管神经网络在图像识别、自然语言等很多领域大放异彩,但回到表格数据的数据挖掘任务中,树模型才是低调王者,如论文《Tabular Data: Deep Learning is Not All You Need》提及的:



深度学习可能不是解决所有机器学习问题的灵丹妙药,通过树模型在处理表格数据时性能与神经网络相当(甚至优于神经网络),而且树模型易于训练使用,有较好的可解释性。


二、树模型的使用


对于决策树等模型的使用,通常是要到scikit-learn、xgboost、lightgbm等机器学习库调用, 这和深度学习库是独立割裂的,不太方便树模型与神经网络的模型融合。



一个好消息是,Google 开源了 TensorFlow 决策森林(TF-DF),为基于树的模型和神经网络提供统一的接口,可以直接用TensorFlow调用树模型。决策森林(TF-DF)简单来说就是用TensorFlow封装了常用的随机森林(RF)、梯度提升(GBDT)等算法,其底层算法是基于C++的 Yggdrasil 决策森林 (YDF)实现的。


三、TensorFlow构建GBDT实践


TF-DF安装很简单pip install -U tensorflow_decision_forests,有个遗憾是目前只支持Linux环境,如果本地用不了将代码复制到 Google Colab 试试~


  • 本例的数据集用的癌细胞分类的数据集,首先加载下常用的模块及数据集:


importnumpyasnp importpandasaspd importmatplotlib.pyplotasplt importtensorflowastf tf.random.set_seed(123) fromsklearnimportdatasets fromsklearn.model_selectionimporttrain_test_split fromsklearn.metricsimportprecision_score,recall_score,f1_score,roc_curve dataset_cancer=datasets.load_breast_cancer()#加载癌细胞数据集 #print(dataset_cancer['DESCR']) df=pd.DataFrame(dataset_cancer.data,columns=dataset_cancer.feature_names) df['label']=dataset_cancer.target print(df.shape) df.head()



  • 划分数据集,并简单做下数据EDA分析:


# holdout验证法:按3:7划分测试集训练集 x_train,x_test=train_test_split(df,test_size=0.3) # EDA分析:数据统计指标 x_train.describe(include='all')



  • 构建TensorFlow的GBDT模型:TD-DF 一个非常方便的地方是它不需要对数据进行任何预处理。它会自动处理数字和分类特征,以及缺失值,我们只需要将df转换为 TensorFlow 数据集,如下一些超参数设定:



模型方面的树的一些常规超参数,类似于scikit-learn的GBDT



此外,还有带有正则化(dropout、earlystop)、损失函数(focal-loss)、效率方面(goss基于梯度采样)等优化方法:



构建模型、编译及训练,一步到位:


#模型参数 model_tf=tfdf.keras.GradientBoostedTreesModel(loss="BINARY_FOCAL_LOSS") #模型训练 model_tf.compile() model_tf.fit(x=train_ds,validation_freq=0.1)


  • 评估模型效果


##模型评估 可以看到test的准确率已经都接近1,可以再那个困难的数据任务试试~ evaluation=model_tf.evaluate(test_ds,return_dict=True) probs=model_tf.predict(test_ds) fpr,tpr,_=roc_curve(x_test.label,probs) plt.plot(fpr,tpr) plt.title('ROCcurve') plt.xlabel('falsepositiverate') plt.ylabel('truepositiverate') plt.xlim(0,) plt.ylim(0,) plt.show() print(evaluation)


  • 模型解释性 GBDT等树模型还有另外一个很大的优势是解释性,这里TF-DF也有实现。模型情况及特征重要性可以通过print(model_tf.summary())打印出来,



特征重要性支持了几种不同的方法评估:


MEAN_MIN_DEPTH指标。平均最小深度越小,较低的值意味着大量样本是基于此特征进行分类的,变量越重要。



NUM_NODES指标。它显示了给定特征被用作分割的次数,类似split。此外还有其他指标就不一一列举了。



我们还可以打印出模型的具体决策的树结构,通过运行tfdf.model_plotter.plot_model_in_colab(model_tf, tree_idx=0,

max_depth=10),整个过程还是比较清晰的。



小结


基于TensorFlow的TF-DF的树模型方法,我们可以方便训练树模型(特别对于熟练TensorFlow框架的同学),更进一步,也可以与TensorFlow的神经网络模型做效果对比、树模型与神经网络模型融合、利用异构模型先特征表示学习再输入模型(如GBDT+DNN、DNN embedding+GBDT),进一步了解可见如下参考文献。

相关文章
|
2月前
|
人工智能 JavaScript API
零基础构建MCP服务器:TypeScript/Python双语言实战指南
作为一名深耕技术领域多年的博主摘星,我深刻感受到了MCP(Model Context Protocol)协议在AI生态系统中的革命性意义。MCP作为Anthropic推出的开放标准,正在重新定义AI应用与外部系统的交互方式,它不仅解决了传统API集成的复杂性问题,更为开发者提供了一个统一、安全、高效的连接框架。在过去几个月的实践中,我发现许多开发者对MCP的概念理解透彻,但在实际动手构建MCP服务器时却遇到了各种技术壁垒。从环境配置的细节问题到SDK API的深度理解,从第一个Hello World程序的调试到生产环境的部署优化,每一个环节都可能成为初学者的绊脚石。因此,我决定撰写这篇全面的实
534 67
零基础构建MCP服务器:TypeScript/Python双语言实战指南
|
2月前
|
机器学习/深度学习 算法 量子技术
GQNN框架:让Python开发者轻松构建量子神经网络
为降低量子神经网络的研发门槛并提升其实用性,本文介绍一个名为GQNN(Generalized Quantum Neural Network)的Python开发框架。
56 4
GQNN框架:让Python开发者轻松构建量子神经网络
|
1月前
|
人工智能 自然语言处理 安全
Python构建MCP服务器:从工具封装到AI集成的全流程实践
MCP协议为AI提供标准化工具调用接口,助力模型高效操作现实世界。
396 1
|
1月前
|
数据采集 数据可视化 API
驱动业务决策:基于Python的App用户行为分析与可视化方案
驱动业务决策:基于Python的App用户行为分析与可视化方案
|
3月前
|
数据采集 数据可视化 JavaScript
用Python采集CBC新闻:如何借助海外代理IP构建稳定采集方案
本文介绍了如何利用Python技术栈结合海外代理IP采集加拿大CBC新闻数据。内容涵盖使用海外代理IP的必要性、青果代理IP的优势、实战爬取流程、数据清洗与可视化分析方法,以及高效构建大规模新闻采集方案的建议。适用于需要获取国际政治经济动态信息的商业决策、市场预测及学术研究场景。
|
3月前
|
数据采集 Web App开发 自然语言处理
利用Python构建今日头条搜索结果的可视化图表
利用Python构建今日头条搜索结果的可视化图表
|
6月前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【害虫识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
害虫识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了12种常见的害虫种类数据集【"蚂蚁(ants)", "蜜蜂(bees)", "甲虫(beetle)", "毛虫(catterpillar)", "蚯蚓(earthworms)", "蜚蠊(earwig)", "蚱蜢(grasshopper)", "飞蛾(moth)", "鼻涕虫(slug)", "蜗牛(snail)", "黄蜂(wasp)", "象鼻虫(weevil)"】 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Djan
378 1
基于Python深度学习的【害虫识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
|
6月前
|
机器学习/深度学习 设计模式 API
Python 高级编程与实战:构建 RESTful API
本文深入探讨了使用 Python 构建 RESTful API 的方法,涵盖 Flask、Django REST Framework 和 FastAPI 三个主流框架。通过实战项目示例,详细讲解了如何处理 GET、POST 请求,并返回相应数据。学习这些技术将帮助你掌握构建高效、可靠的 Web API。
|
6月前
|
机器学习/深度学习 设计模式 测试技术
Python 高级编程与实战:构建自动化测试框架
本文深入探讨了Python中的自动化测试框架,包括unittest、pytest和nose2,并通过实战项目帮助读者掌握这些技术。文中详细介绍了各框架的基本用法和示例代码,助力开发者快速验证代码正确性,减少手动测试工作量。学习资源推荐包括Python官方文档及Real Python等网站。

热门文章

最新文章

推荐镜像

更多