Spark之获取GBT二分类函数的概率值

简介:   在Spark中,GBT(Gradient Boost Trees,提升树)函数用于实现机器学习中的提升树算法,目前仅支持二分类算法。

  在Spark中,GBT(Gradient Boost Trees,提升树)函数用于实现机器学习中的提升树算法,目前仅支持二分类算法。笔者在实际工作中需要获得其预测的概率值,无奈该函数没有相应的方法。
  经过笔者几天的奋斗,终于找到了解决之道。下面将分享在Spark中如何获取GBT二分类函数的概率值的思路。
  首先,查看GBT函数的Scala源代码,其中的predict函数如下:
  predict函数
  其中的prediction值是我们计算概率值所需要的,prediction的值为_treePredictions(向量)与_treeWeights(向量)的点积,numTrees为GBTClassifier所使用的树的数量。_treePredictions为每棵决策树的预测值组成的向量,_treeWeights为每颗树的权重组成的向量。
  那么该如何计算_treePredictions与_treeWeights呢?
  _treeWeights的计算可直接调用GBTClassifier中的treeWeights方法,输出的数据类型为list。以下为例子:(该例子的数据集和GBT测试代码见附录)
  _treeWeights
  接下来讲述_treePredictions的计算。
  首先可以利用toDebugString方法可以知道每棵决策树的具体情形,以下为例子:
  toDebugString
  利用trees方法可知道决策树的情形,trees方法输出为list.
  trees
  取该list的下标,可以获得每颗树的模型,类型为pyspark.ml.regression.DecisionTreeRegressionMdel.
  每颗树的模型
  这样就可以利用DecisionTreeRegressionMdel的transfrom(data)得到每棵决策树的预测值,其中data为测试数据集。
  最后,我们利用以下公式就能得到GBT函数预测的概率值:
  

probability=1eprediction,

其中,其中prediction为Scala源代码中的prediction值,它的计算方法如上所述.
  当probability > 0.5时,分类为1,否则为0.
  这样,我们就能在Spark中获取GBT二分类函数的概率值了。读者可以结合以上的计算思路和实际工作,自己来编写预测概率值的代码啦~~
附录1:Spark的GBTClassifier的scala源代码网址: https://fossies.org/linux/spark/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
附录2:测试数据集:
测试数据集
附录3:测试代码:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import StringIndexer
from pyspark.ml.classification import GBTClassifier
df = spark.createDataFrame([(1.0,Vectors.dense(1,2,3)), (0.0,Vectors.dense(4,5,6)), (1.0,Vectors.dense(1,2,4)), (0.0,Vectors.dense(3,5,6)), (1.0,Vectors.dense(1,3,2)), (1.0,Vectors.dense(1,3,4)),], ["label", "features"])
stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
si_model = stringIndexer.fit(df)
td = si_model.transfrom(df)
gbt = GBTClassifier(maxIter=5, maxDeth=2, labelCol="indexed", seed=42)
model = gbt.fit(td)



参考网址:
1.spark如何获得分类概率: http://blog.csdn.net/seu_yang/article/details/52118683
2.Predicting probabilities of classes in case of Gradient Boosting Trees in Spark using the tree output: https://stackoverflow.com/questions/37303855/predicting-probabilities-of-classes-in-case-of-gradient-boosting-trees-in-spark

目录
相关文章
|
存储 分布式计算 并行计算
大数据Spark RDD 函数 1
大数据Spark RDD 函数
114 0
spark3.5.1中内置函数大全
spark3.5.1中内置函数大全
|
5月前
|
SQL 分布式计算 数据处理
MaxCompute操作报错合集之使用Spark查询时函数找不到的原因是什么
MaxCompute是阿里云提供的大规模离线数据处理服务,用于大数据分析、挖掘和报表生成等场景。在使用MaxCompute进行数据处理时,可能会遇到各种操作报错。以下是一些常见的MaxCompute操作报错及其可能的原因与解决措施的合集。
|
分布式计算 Java Spark
图解Spark Graphx实现顶点关联邻接顶点的collectNeighbors函数原理
图解Spark Graphx实现顶点关联邻接顶点的collectNeighbors函数原理
54 0
|
7月前
|
SQL 分布式计算 Spark
Spark【Spark SQL(四)UDF函数和UDAF函数】
Spark【Spark SQL(四)UDF函数和UDAF函数】
|
分布式计算 大数据 数据挖掘
大数据Spark RDD 函数 2
大数据Spark RDD 函数
105 0
|
SQL JSON 分布式计算
spark2 sql读取数据源编程学习样例2:函数实现详解
spark2 sql读取数据源编程学习样例2:函数实现详解
95 0
spark2 sql读取数据源编程学习样例2:函数实现详解
|
SQL 存储 分布式计算
Spark强大的函数扩展功能
Spark强大的函数扩展功能
|
分布式计算 Java Linux
【Spark 3.0-JavaAPI-pom】体验JavaRDD函数封装变化
【Spark 3.0-JavaAPI-pom】体验JavaRDD函数封装变化
196 0
【Spark 3.0-JavaAPI-pom】体验JavaRDD函数封装变化
|
分布式计算 Java Scala
一天学完spark的Scala基础语法教程四、方法与函数(idea版本)
一天学完spark的Scala基础语法教程四、方法与函数(idea版本)
107 0
一天学完spark的Scala基础语法教程四、方法与函数(idea版本)