Pyspark Onv-vs-Rest分类器似乎没有提供概率。有没有办法做到这一点?
我在下面添加代码。我正在添加标准的多类分类器进行比较。
from pyspark.ml.classification import LogisticRegression, OneVsRest
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
load data file.
inputData = spark.read.format("libsvm") \
.load("/data/mllib/sample_multiclass_classification_data.txt")
(train, test) = inputData.randomSplit([0.8, 0.2])
instantiate the base classifier.
lr = LogisticRegression(maxIter=10, tol=1E-6, fitIntercept=True)
instantiate the One Vs Rest Classifier.
ovr = OneVsRest(classifier=lr)
train the multiclass model.
ovrModel = ovr.fit(train)
lrm = lr.fit(train)
score the model on test data.
predictions = ovrModel.transform(test)
predictions2 = lrm.transform(test)
predictions.show(6)
predictions2.show(6)
我认为你不能访问概率(置信度)向量,因为它需要置信度的最大值并降低置信度向量。要进行测试,您可以复制该类并对其进行修改并删除.drop(accColName)
http://spark.apache.org/docs/2.0.1/api/python/_modules/pyspark/ml/classification.html
labelUDF = udf(
lambda predictions: float(max(enumerate(predictions), key=operator.itemgetter(1))[0]),
DoubleType())
return aggregatedDataset.withColumn(
self.getPredictionCol(), labelUDF(aggregatedDataset[accColName])).drop(accColName)
版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。