最近,深入研究了一下数据挖掘竞赛神器——XGBoost的算法原理和模型数据结构

简介: 从事数据挖掘相关工作的人肯定都知道XGBoost算法,这个曾经闪耀于数据挖掘竞赛的一代神器,是2016年由陈天齐大神所提出来的经典算法。本质上来讲,XGBoost算作是对GBDT算法的一种优化实现,但除了在集成算法理念层面的传承,具体设计细节其实还是有很大差别的。最近深入学习了一下,并简单探索了底层设计的数据结构,不禁感慨算法之精妙!聊作总结,以资后鉴!

640.png

2016年,陈天齐受邀参加关于XGBoost的分享会


XGBoost是机器学习中的一种集成算法,按照三大集成流派来划分,属于Boosting流派。Boosting流派也是集成算法中最为活跃和强大的流派,除了XGBoost之外,前有Adaboost和GBDT,后有LightGBM和CatBoost。当然,LightGBM和CatBoost与XGBoost一般都被视作是GBDT的改良和优化实现。


XGBoost算法原理其实已经非常成熟且完备,网络上关于这方面的分享也不计其数,所以本文也不想重复前人的工作对其长篇大论,而将写作目的聚焦如下:一是从解释公式的角度分享个人关于XGBoost算法原理的理解;二是简单探究下XGBoost底层的数据结构设计。其中后者类似于前期推文《数据科学:Sklearn中的决策树,底层是如何设计和存储的?》的定位。


01 学习公式推导,理解原理


为了理解XGBoost的算法原理,下面主要分享与该算法相关的5个主要公式的推导过程。所有公式主要源自XGBoost官方文档(https://xgboost.readthedocs.io/en/latest),查阅相关论文也可以找得到。


公式1——Boosting集成算法的加法模型:


微信截图_20220528072133.png


该公式非常简单,但对于理解后面的公式很重要。XGBoost作为一个Boosting算法,其集成的思路遵循典型的加法模型,即集成算法的输出等于各基学习器的输出之和(回归问题用求和很好理解,分类问题时其实是在拟合logloss,类似于逻辑回归中的做法,后文有提到),上式中fk(x)代表单个基学习器,K表示基学习器的个数。如下图中含有2个基学习器的集成模型为例,"儿子"的集成模型输出为两个基学习器的各自输出之和,即2+0.9=2.9;"爷爷"的集成模型输出为两个基学习器的输出之和,即-1-0.9=-1.9。


640.png


公式2——XGBoost中基学习器的目标函数


微信截图_20220528073120.png


机器学习有三要素:模型、策略和算法。其中策略则包含了如何界定或评估模型的好坏,换句话说就涉及到定义损失函数。与GBDT中不断拟合残差不同,XGBoost在其基础上增加了模型的结构损失,可理解为模型学习的代价;而残差则对应经验损失,可理解为模型学习能力的差距和缺失。


在上述目标函数(目标函数是比损失函数更为high-level的术语,一般包括两部分:目标函数=损失+正则项,越小越好。除了目标函数和损失函数之外,还有一个相关的术语叫代价函数,某种意义上代价函数可近似理解为损失函数。)中,求和的第一部分定义了当前模型训练结果与真实值的差距大小,称作经验风险,具体度量方法取决于相应的损失函数定义,典型的损失函数为:回归问题对应MSE损失,分类问题对应logloss损失;求和的第二部分体现了模型的结构风险,影响模型的泛华能力。如果基学习器选择决策树,那么这里的结构风险定义为:


微信截图_20220528073150.png


其中,γ和λ均为正则项系数,T为决策树中叶子节点的个数。注意,这里是叶子节点的个数,而非决策树中节点的数量(CART决策树进行CCP后剪枝时,用到的正则项是计算所有节点个数)。另外,这是一般介绍XGBoost原理时的公式,也是陈天齐最早论文中的写法,在Python的xgboost工具包中,模型初始化参数中除了与这两个参数对应的gamma和reg_lambda之外,还有reg_alpha参数,表示的一阶正则项,此时可写作:


微信截图_20220528073213.png


公式3——XGBoost中的Taylor二阶展开近似


微信截图_20220528073225.png


理解:这个Taylor二阶展开近似可谓是XGBoost的灵魂所在,也是最能体现其相较于GBDT的玄妙和强大之处。为了更好的解释上述近似,首先给出通常意义下的Taylor展开式:


微信截图_20220528073239.png


当然,上述公式也只是展开近似到了二阶,只要f(x)无限可导,则可有更高阶的近似。在XGBoost中,应用Taylor二阶展开近似其实是只对模型的经验风险部分,也就是公式2中第一部分求和的每个子项。这里,再次给出单个的表达式:


微信截图_20220528073249.png


在上式中,下标i对应的是训练集中的第i个样本,上角标t和t-1对应的集成算法中的第t轮和第t-1轮。那么进一步地,谁是f,谁是x,谁是△x呢?这里,f就是个函数记号,对应的是损失函数中的l,重点需要理解x和△x。


在模型训练时,训练数据集其实是确定的,在每个基学习器中都是那一套固定值,所以yi自然也不例外,在上述公式中就可看做是常数。在集成模型训练的第t轮,此时模型训练的目的是基于前t-1轮的训练结果(此时已经确定)来寻找最优的第t轮结果,以此来得到当前可能的最小损失。


实际上,在集成学习中,第一个基学习器往往已经能够拟合出大部分的结果出来,例如在惯用的拟合年龄的例子中,假设要拟合的是100这个结果,那么很可能第一个基学习器的拟合结果是90,而后面的N-1个学习器只是在不断的修正这个残差:10。举这个例子的目的是想表达:在上述公式中前t-1轮的拟合结果y_hat其实对应的就是f(x+△x)中的x,而第t轮的拟合值则可视作是浮动变量△x。至此,照着Taylor展开式的样式,上述目标函数可具体展开为:


微信截图_20220528073301.png


其中gi和hi分别为一阶导和二阶导,分别写作如下:


微信截图_20220528073314.png


进一步地,记y为真实值,y_hat为拟合值,则对于回归问题,适用最为常用的MSE损失,则其loss函数及相应的一阶导和二阶导分别为:


微信截图_20220528073323.png


而对于分类问题,以二分类问题为例,XGBoost中默认的损失函数为logloss,相应的loss函数及对应一阶导和二阶导分别为:


微信截图_20220528073646.png


公式4——决策树中的最优叶子权重求解


微信截图_20220528073704.png


XGBoost理论上可以支持任何基学习器,但其实最为常用的还是使用决策树,Python中的xgboost工具库也是默认以gbtree作为基学习器。在决策树中,第t轮训练得到的最优决策树实际上就是寻求最优的叶子权重的过程,所以理解这个最优的叶子权重尤为重要。


好在上述两个公式的求解非常简单易懂,甚至说是初中的数学知识范畴,可比SVM中的拉格朗日对偶问题容易理解多了。首先看如下转换:


微信截图_20220528073729.png



第一步的约等号当然是来源于公式3中的Taylor二阶展开近似,只不过此时将常数部分省略而已,需要注意的是此时的∑求解是以样本为粒度的求解,即此时i为样本序号,n为样本总数。而在第二步的等号转换中,则是以叶子节点为粒度,将落在同一叶子节点的多个样本进行了聚合,此时落在同一叶子节点上的所有样本预测结果均为其叶子权重ωj,各个叶子节点内部的求和对应为内部的∑。


有了以上的近似展开和各叶子节点的汇聚,则可以引出如下公式:


微信截图_20220528073742.png


其中Gj和Hj分别为第j个叶子节点所有样本的一阶导和二阶导的求和,即:


微信截图_20220528073801.png


上述目标公式可看做是T个一元二次表达式的求和,其中每个一元二次表达式中的变量为ωj。显然,求解形如f(x)=ax^2+bx+c的最小值问题是一个初中阶段的数学问题,进而容易得出最优的ωj及此时对应的损失函数最小取值结果为:


微信截图_20220528073823.png


这里的一元二次函数一定存在最小值,因为其二次项的系数1/2(Hj+λ)一定是个正数!


公式5——决策树的分裂增益


微信截图_20220528073847.png


公式4解决的是叶子节点的最优权重问题,那么实际上是绕过了一个前置问题:即决策树的内部节点如何进行分裂?内部节点如何进行分裂其实可进一步细分为两个子问题:①选择哪个特征进行分裂?②以什么阈值划分左右子树?


第一个问题很好解决,最简单也是一直沿用至今的做法都是对所有特征逐一遍历,对比哪个特征最带来增益最大。而对于第二个问题,其实也是采用遍历寻优的方法来得到最优分裂阈值,至于如何遍历寻优,其实还可以进一步细分为两个问题:i)选择哪些候选分裂阈值?ii)如何度量哪个分裂阈值更优?


选择哪些候选分裂阈值就涉及到很多技巧,XGBoost和LightGBM都采用了直方图法来简化可能的最优分裂点候选值,这里涉及的细节还有很多,暂且不谈;而对于如何度量分裂阈值更优的问题,则刚好可以利用前面公式4中的结论——叶子节点在最优权重下的最小损失表达式。以此为基础,度量最优分裂阈值的流程是这样的:

  • 如果该节点不进行分裂,即将其视作一个叶子节点,可以得到当前的最小损失取值;
  • 对于选定的特征及阈值,将当前节点的所有样本切分为左右子树,进而可以得到左右子树对应的最小损失取值;

那么,从该节点直接作为叶子节点到将其分裂为左右两个子叶子节点是否会带来损失的降低呢?所以只需将分裂前后的损失相减即可!那么相减之后γT部分为什么变为-γ了呢?其实就是因为在分裂之前该部分的正则项对应1个叶子节点,而分裂之后则对应2个叶子节点,所以两部分的γT相减即为-γ。


以上,就是关于XGBoost中的几个核心公式推导环节的理解,相信理解了这5个公式就基本能够理解XGBoost是如何设计和实现的了。当然,XGBoost的强大和设计巧妙之处绝不止于上述算法原理,其实还有很多实用的技巧和优化,这也构成了XGBoost的scalable能力,具体可参考论文《XGBoost: A Scalable Tree Boosting System》。


02 查看源码,了解底层数据结构


第一部分主要介绍了XGBoost中的核心公式部分,下面简要分享一下XGBoost中的底层数据结构设计。之所以增加这部分工作,仍然是因为近期在做部分预研工作的需要,所以重点探究了一下XGBoost中底层是如何存储所有基学习器的,也就是各个决策树的训练结果。


为了了解XGBoost中是如何存储训练后的各个决策树,我们查看分类器的模型训练部分源码,经过简单查看就可以定位到如下代码:


640.png


也就是说,XGBClassifer模型训练后的结果应该是保存在_Booster属性中。


当然,上述查看的xgboost提供的sklearn类型接口,在其原生训练方法中,实际上是调用xgboost.train函数来实现的模型训练,此时无论是回归任务还是分类任务,都是调用的这个函数,只是通过目标函数的不同来区分不同的任务类型而已。


为了进一步查看这个_Booster属性,我们实际训练一个XGBoost二分类模型,运用如下简单代码示例:


from sklearn.datasets import load_iris
from xgboost import XGBClassifier
X, y = load_iris(return_X_y=True)
# 原生鸢尾花数据集是三分类,对其进行采用为二分类
X = X[y<2]
y = y[y<2]
xgb = XGBClassifier(use_label_encoder=False)
xgb.fit(X, y, eval_metric='logloss')


而后,通过dir属性查看一下这个_Booster的结果:


640.png


实际上,这个_Booster属性是xgboost中定义的一个类,上述结果也可直接查看xgboost中关于Booster类的定义。在上述dir结果中,有几个函数值得重点关注:


  • save_model:用于将xgboost模型训练结果存储为文件,而且xgboost非常友好的是在1.0.0版本以后,直接支持存储为json格式,这可比pickle格式什么的方便多了,大大增强可读性


640.png


  • load_model:有save_model就一定有load_model,二者是互逆操作,即load_model可将save_model的json文件结果读取为一个xgboost模型
  • dump_model:实际上,dump也有存储的含义,例如json中定义的读写函数就是load和dump。而在xgboost中,dump_model与save_model的区别在于:dump_model的存储结果是便于人类阅读,但该过程是单向的,即dump的结果不能再load回去;
  • get_dump:与dump_model类似,只不过不是存储为文件,而只是返回一个字符串;
  • trees_to_dataframe:含义非常明了,就是将训练后的所有树信息转化为一个dataframe。


这里,首先看下trees_to_dataframe的结果:


640.png


似乎从列名来推断,除了最后的Cover和Category两个字段含义不甚明了之外,其他字段的含义都非常清楚,所以也不再做过多解释。


之后,再探索一下save_model和dump_model的结果。既然dump_model的结果便于人类阅读,那么就首先查看这个结果:


640.png


这里截取了dump_model后的txt文件的三个决策树信息,可见dump_model的结果仅保留了各决策树的分裂相关信息,以booster[0]第一个决策树为例,该决策树有三个节点,节点的缩进关系表达了子节点对应关系,内部节点标号标识了所选分裂特征及对应阈值,但叶子节点后面的数值实际上并非是其权重。


而后,再探索一下save_model的json文件结果,首先查看整个json的结构关系:


640.png


其中,trees部分是一个含有100个item的列表,对应了100个基学习器的信息,进一步查看,类似于sklearn中定义的Array-based Tree Representation形式,这里的决策树各个节点信息仍然是Array-based,即各属性的第i个取值表示了相应的第i个节点的对应属性。主要字段及含义如下:


640.png


值得指出的是,经过对比left_children和right_children以及parents三个属性的取值,容易推断出xgboost中的决策树节点编号是采用的层级遍历,这与sklearn中采用的前序遍历是不同的。


以上,只是给出了xgboost中关于基学习器信息的一些简单简单探索,感兴趣的可以进一步多做尝试以及查看相应源码设计,希望能对理解xgboost的原理有所帮助。


640.png


目录
相关文章
|
1月前
|
机器学习/深度学习 算法 数据挖掘
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构。本文介绍了K-means算法的基本原理,包括初始化、数据点分配与簇中心更新等步骤,以及如何在Python中实现该算法,最后讨论了其优缺点及应用场景。
103 4
|
3天前
|
算法 Java 数据库
理解CAS算法原理
CAS(Compare and Swap,比较并交换)是一种无锁算法,用于实现多线程环境下的原子操作。它通过比较内存中的值与预期值是否相同来决定是否进行更新。JDK 5引入了基于CAS的乐观锁机制,替代了传统的synchronized独占锁,提升了并发性能。然而,CAS存在ABA问题、循环时间长开销大和只能保证单个共享变量原子性等缺点。为解决这些问题,可以使用版本号机制、合并多个变量或引入pause指令优化CPU执行效率。CAS广泛应用于JDK的原子类中,如AtomicInteger.incrementAndGet(),利用底层Unsafe库实现高效的无锁自增操作。
理解CAS算法原理
|
5天前
|
存储 运维 监控
探索局域网电脑监控软件:Python算法与数据结构的巧妙结合
在数字化时代,局域网电脑监控软件成为企业管理和IT运维的重要工具,确保数据安全和网络稳定。本文探讨其背后的关键技术——Python中的算法与数据结构,如字典用于高效存储设备信息,以及数据收集、异常检测和聚合算法提升监控效率。通过Python代码示例,展示了如何实现基本监控功能,帮助读者理解其工作原理并激发技术兴趣。
46 20
|
29天前
|
数据采集 存储 算法
Python 中的数据结构和算法优化策略
Python中的数据结构和算法如何进行优化?
|
1月前
|
算法
数据结构之路由表查找算法(深度优先搜索和宽度优先搜索)
在网络通信中,路由表用于指导数据包的传输路径。本文介绍了两种常用的路由表查找算法——深度优先算法(DFS)和宽度优先算法(BFS)。DFS使用栈实现,适合路径问题;BFS使用队列,保证找到最短路径。两者均能有效查找路由信息,但适用场景不同,需根据具体需求选择。文中还提供了这两种算法的核心代码及测试结果,验证了算法的有效性。
103 23
|
1月前
|
算法 容器
令牌桶算法原理及实现,图文详解
本文介绍令牌桶算法,一种常用的限流策略,通过恒定速率放入令牌,控制高并发场景下的流量,确保系统稳定运行。关注【mikechen的互联网架构】,10年+BAT架构经验倾囊相授。
令牌桶算法原理及实现,图文详解
|
24天前
|
存储 人工智能 缓存
【AI系统】布局转换原理与算法
数据布局转换技术通过优化内存中数据的排布,提升程序执行效率,特别是对于缓存性能的影响显著。本文介绍了数据在内存中的排布方式,包括内存对齐、大小端存储等概念,并详细探讨了张量数据在内存中的排布,如行优先与列优先排布,以及在深度学习中常见的NCHW与NHWC两种数据布局方式。这些布局方式的选择直接影响到程序的性能,尤其是在GPU和CPU上的表现。此外,还讨论了连续与非连续张量的概念及其对性能的影响。
46 3
|
29天前
|
机器学习/深度学习 人工智能 算法
探索人工智能中的强化学习:原理、算法与应用
探索人工智能中的强化学习:原理、算法与应用
|
28天前
|
并行计算 算法 测试技术
C语言因高效灵活被广泛应用于软件开发。本文探讨了优化C语言程序性能的策略,涵盖算法优化、代码结构优化、内存管理优化、编译器优化、数据结构优化、并行计算优化及性能测试与分析七个方面
C语言因高效灵活被广泛应用于软件开发。本文探讨了优化C语言程序性能的策略,涵盖算法优化、代码结构优化、内存管理优化、编译器优化、数据结构优化、并行计算优化及性能测试与分析七个方面,旨在通过综合策略提升程序性能,满足实际需求。
62 1
|
1月前
|
缓存 算法 网络协议
OSPF的路由计算算法:原理与应用
OSPF的路由计算算法:原理与应用
50 4