TensorFlow决策森林构建GBDT(Python)

简介: TensorFlow决策森林构建GBDT(Python)

一、Deep Learning is Not All You Need


尽管神经网络在图像识别、自然语言等很多领域大放异彩,但回到表格数据的数据挖掘任务中,树模型才是低调王者,如论文《Tabular Data: Deep Learning is Not All You Need》提及的:



深度学习可能不是解决所有机器学习问题的灵丹妙药,通过树模型在处理表格数据时性能与神经网络相当(甚至优于神经网络),而且树模型易于训练使用,有较好的可解释性。


二、树模型的使用


对于决策树等模型的使用,通常是要到scikit-learn、xgboost、lightgbm等机器学习库调用, 这和深度学习库是独立割裂的,不太方便树模型与神经网络的模型融合。



一个好消息是,Google 开源了 TensorFlow 决策森林(TF-DF),为基于树的模型和神经网络提供统一的接口,可以直接用TensorFlow调用树模型。决策森林(TF-DF)简单来说就是用TensorFlow封装了常用的随机森林(RF)、梯度提升(GBDT)等算法,其底层算法是基于C++的 Yggdrasil 决策森林 (YDF)实现的。


三、TensorFlow构建GBDT实践


TF-DF安装很简单pip install -U tensorflow_decision_forests,有个遗憾是目前只支持Linux环境,如果本地用不了将代码复制到 Google Colab 试试~


  • 本例的数据集用的癌细胞分类的数据集,首先加载下常用的模块及数据集:


importnumpyasnp importpandasaspd importmatplotlib.pyplotasplt importtensorflowastf tf.random.set_seed(123) fromsklearnimportdatasets fromsklearn.model_selectionimporttrain_test_split fromsklearn.metricsimportprecision_score,recall_score,f1_score,roc_curve dataset_cancer=datasets.load_breast_cancer()#加载癌细胞数据集 #print(dataset_cancer['DESCR']) df=pd.DataFrame(dataset_cancer.data,columns=dataset_cancer.feature_names) df['label']=dataset_cancer.target print(df.shape) df.head()



  • 划分数据集,并简单做下数据EDA分析:


# holdout验证法:按3:7划分测试集训练集 x_train,x_test=train_test_split(df,test_size=0.3) # EDA分析:数据统计指标 x_train.describe(include='all')



  • 构建TensorFlow的GBDT模型:TD-DF 一个非常方便的地方是它不需要对数据进行任何预处理。它会自动处理数字和分类特征,以及缺失值,我们只需要将df转换为 TensorFlow 数据集,如下一些超参数设定:



模型方面的树的一些常规超参数,类似于scikit-learn的GBDT



此外,还有带有正则化(dropout、earlystop)、损失函数(focal-loss)、效率方面(goss基于梯度采样)等优化方法:



构建模型、编译及训练,一步到位:


#模型参数 model_tf=tfdf.keras.GradientBoostedTreesModel(loss="BINARY_FOCAL_LOSS") #模型训练 model_tf.compile() model_tf.fit(x=train_ds,validation_freq=0.1)


  • 评估模型效果


##模型评估 可以看到test的准确率已经都接近1,可以再那个困难的数据任务试试~ evaluation=model_tf.evaluate(test_ds,return_dict=True) probs=model_tf.predict(test_ds) fpr,tpr,_=roc_curve(x_test.label,probs) plt.plot(fpr,tpr) plt.title('ROCcurve') plt.xlabel('falsepositiverate') plt.ylabel('truepositiverate') plt.xlim(0,) plt.ylim(0,) plt.show() print(evaluation)


  • 模型解释性 GBDT等树模型还有另外一个很大的优势是解释性,这里TF-DF也有实现。模型情况及特征重要性可以通过print(model_tf.summary())打印出来,



特征重要性支持了几种不同的方法评估:


MEAN_MIN_DEPTH指标。平均最小深度越小,较低的值意味着大量样本是基于此特征进行分类的,变量越重要。



NUM_NODES指标。它显示了给定特征被用作分割的次数,类似split。此外还有其他指标就不一一列举了。



我们还可以打印出模型的具体决策的树结构,通过运行tfdf.model_plotter.plot_model_in_colab(model_tf, tree_idx=0,

max_depth=10),整个过程还是比较清晰的。



小结


基于TensorFlow的TF-DF的树模型方法,我们可以方便训练树模型(特别对于熟练TensorFlow框架的同学),更进一步,也可以与TensorFlow的神经网络模型做效果对比、树模型与神经网络模型融合、利用异构模型先特征表示学习再输入模型(如GBDT+DNN、DNN embedding+GBDT),进一步了解可见如下参考文献。

相关文章
|
1天前
|
数据挖掘 PyTorch TensorFlow
|
3天前
|
JSON 安全 数据安全/隐私保护
实战指南:Python中OAuth与JWT的完美结合,构建安全认证防线
【9月更文挑战第9天】当今互联网应用的安全性至关重要,尤其在处理用户数据和个人隐私时。OAuth 和 JWT 是两种广泛使用的认证机制,各具优势。本文探讨如何在 Python 中结合 OAuth 和 JSON Web Tokens (JWT) 构建安全可靠的认证系统。OAuth 允许第三方应用获取有限访问权限而不暴露用户密码;JWT 则是一种轻量级数据交换格式,用于安全传输信息。结合使用这两种技术,可以在确保安全性的同时简化认证流程。
9 4
|
3天前
|
机器学习/深度学习 算法 Python
从菜鸟到大师:一棵决策树如何引领你的Python机器学习之旅
【9月更文挑战第9天】在数据科学领域,机器学习如同璀璨明珠,吸引无数探索者。尤其对于新手而言,纷繁复杂的算法常让人感到迷茫。本文将以决策树为切入点,带您从Python机器学习的新手逐步成长为高手。决策树以其直观易懂的特点成为入门利器。通过构建决策树分类器并应用到鸢尾花数据集上,我们展示了其基本用法及效果。掌握决策树后,还需深入理解其工作原理,调整参数,并探索集成学习方法,最终将所学应用于实际问题解决中,不断提升技能。愿这棵智慧之树助您成为独当一面的大师。
13 3
|
2天前
|
数据采集 JavaScript 前端开发
构建你的首个Python网络爬虫
【9月更文挑战第8天】本文将引导你从零开始,一步步构建属于自己的Python网络爬虫。我们将通过实际的代码示例和详细的步骤解释,让你理解网络爬虫的工作原理,并学会如何使用Python编写简单的网络爬虫。无论你是编程新手还是有一定基础的开发者,这篇文章都将为你打开网络数据获取的新世界。
|
5天前
|
机器学习/深度学习 算法 Python
决策树下的智慧果实:Python机器学习实战,轻松摘取数据洞察的果实
【9月更文挑战第7天】当我们身处数据海洋,如何提炼出有价值的洞察?决策树作为一种直观且强大的机器学习算法,宛如智慧之树,引领我们在繁复的数据中找到答案。通过Python的scikit-learn库,我们可以轻松实现决策树模型,对数据进行分类或回归分析。本教程将带领大家从零开始,通过实际案例掌握决策树的原理与应用,探索数据中的秘密。
14 1
|
6天前
|
算法 程序员 Linux
Python编程入门:构建你的第一个程序
【9月更文挑战第4天】编程是现代技术发展的基石,而Python作为一门简洁、易学且功能强大的编程语言,已成为众多初学者的首选。本文将引导你通过一个简单的Python程序,探索编程世界的奥秘,并了解如何利用Python实现基本的算法逻辑。无论你是完全的新手还是希望巩固基础的开发者,这篇文章都将为你提供一个清晰的学习路径。从安装Python环境开始,到编写第一个程序,我们将一步步揭开编程的神秘面纱。
|
11天前
|
数据采集 JavaScript 前端开发
构建简易Python爬虫:抓取网页数据入门指南
【8月更文挑战第31天】在数字信息的时代,数据抓取成为获取网络资源的重要手段。本文将引导你通过Python编写一个简单的网页爬虫,从零基础到实现数据抓取的全过程。我们将一起探索如何利用Python的requests库进行网络请求,使用BeautifulSoup库解析HTML文档,并最终提取出有价值的数据。无论你是编程新手还是有一定基础的开发者,这篇文章都将为你打开数据抓取的大门。
|
2天前
|
机器学习/深度学习 数据挖掘 TensorFlow
从数据小白到AI专家:Python数据分析与TensorFlow/PyTorch深度学习的蜕变之路
【9月更文挑战第10天】从数据新手成长为AI专家,需先掌握Python基础语法,并学会使用NumPy和Pandas进行数据分析。接着,通过Matplotlib和Seaborn实现数据可视化,最后利用TensorFlow或PyTorch探索深度学习。这一过程涉及从数据清洗、可视化到构建神经网络的多个步骤,每一步都需不断实践与学习。借助Python的强大功能及各类库的支持,你能逐步解锁数据的深层价值。
9 0
|
11天前
|
Java 缓存 数据库连接
揭秘!Struts 2性能翻倍的秘诀:不可思议的优化技巧大公开
【8月更文挑战第31天】《Struts 2性能优化技巧》介绍了提升Struts 2 Web应用响应速度的关键策略,包括减少配置开销、优化Action处理、合理使用拦截器、精简标签库使用、改进数据访问方式、利用缓存机制以及浏览器与网络层面的优化。通过实施这些技巧,如懒加载配置、异步请求处理、高效数据库连接管理和启用GZIP压缩等,可显著提高应用性能,为用户提供更快的体验。性能优化需根据实际场景持续调整。
35 0
|
11天前
|
机器学习/深度学习 人工智能 TensorFlow
深度学习入门:使用Python和TensorFlow构建你的第一个神经网络
【8月更文挑战第31天】 本文是一篇面向初学者的深度学习指南,旨在通过简洁明了的语言引导读者了解并实现他们的第一个神经网络。我们将一起探索深度学习的基本概念,并逐步构建一个能够识别手写数字的简单模型。文章将展示如何使用Python语言和TensorFlow框架来训练我们的网络,并通过直观的例子使抽象的概念具体化。无论你是编程新手还是深度学习领域的新兵,这篇文章都将成为你探索这个激动人心领域的垫脚石。