spark 针对决策树进行交叉验证

简介:
复制代码
from pyspark import SparkContext, SQLContext
from pyspark.ml import Pipeline
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import StringIndexer, VectorIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Load the data stored in LIBSVM format as a DataFrame.
data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

# Index labels, adding metadata to the label column.
# Fit on whole dataset to include all labels in index.
labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data)
# Automatically identify categorical features, and index them.
# We specify maxCategories so features with > 4 distinct values are treated as continuous.
featureIndexer =\
    VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)

# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])

# Train a DecisionTree model.
dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures")

# Chain indexers and tree in a Pipeline
pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt])

# Train model.  This also runs the indexers.
model = pipeline.fit(trainingData)

# Make predictions.
predictions = model.transform(testData)

# Select example rows to display.
predictions.select("prediction", "indexedLabel", "features").show(5)

# Select (prediction, true label) and compute test error
evaluator = MulticlassClassificationEvaluator(
    labelCol="indexedLabel", predictionCol="prediction", metricName="precision")
accuracy = evaluator.evaluate(predictions)
print("Test Error = %g " % (1.0 - accuracy))

treeModel = model.stages[2]
# summary only
print(treeModel)


#############################

from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

# Create ParamGrid for Cross Validation
paramGrid = (ParamGridBuilder()
             .addGrid(lr.regParam, [0.01, 0.5, 2.0])
             .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])
             .addGrid(lr.maxIter, [1, 5, 10])
             .build())
Copy to clipboardCopy
# Create 5-fold CrossValidator
cv = CrossValidator(estimator=lr, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=5)

# Run cross validations
cvModel = cv.fit(trainingData)
# this will likely take a fair amount of time because of the amount of models that we're creating and testing

# Use test set here so we can measure the accuracy of our model on new data
predictions = cvModel.transform(testData)

# cvModel uses the best model found from the Cross Validation
# Evaluate best model
evaluator.evaluate(predictions)

#We can also access the model’s feature weights and intercepts easily


print 'Model Intercept: ', cvModel.bestModel.intercept
复制代码

 

 

复制代码
ML provides CrossValidator class which can be used to perform cross-validation and parameter search. Assuming your data is already preprocessed you can add cross-validation as follows:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// [label: double, features: vector]
trainingData org.apache.spark.sql.DataFrame = ??? 
val nFolds: Int = ???
val NumTrees: Int = ???
val metric: String = ???

val rf = new RandomForestClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setNumTrees(NumTrees)

val pipeline = new Pipeline().setStages(Array(rf)) 

val paramGrid = new ParamGridBuilder().build() // No parameter search

val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  // "f1" (default), "weightedPrecision", "weightedRecall", "accuracy"
  .setMetricName(metric) 

val cv = new CrossValidator()
  // ml.Pipeline with ml.classification.RandomForestClassifier
  .setEstimator(pipeline)
  // ml.evaluation.MulticlassClassificationEvaluator
  .setEvaluator(evaluator) 
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(nFolds)

val model = cv.fit(trainingData) // trainingData: DataFrame
Using PySpark:

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

trainingData = ... # DataFrame[label: double, features: vector]
numFolds = ... # Integer

rf = RandomForestClassifier(labelCol="label", featuresCol="features")
evaluator = MulticlassClassificationEvaluator() # + other params as in Scala    

pipeline = Pipeline(stages=[rf])

crossval = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=paramGrid,
    evaluator=evaluator,
    numFolds=numFolds)

model = crossval.fit(trainingData)
复制代码

 














本文转自张昺华-sky博客园博客,原文链接:http://www.cnblogs.com/bonelee/p/7803849.html,如需转载请自行联系原作者



相关文章
|
SQL 分布式计算 数据可视化
数据分享|Python、Spark SQL、MapReduce决策树、回归对车祸发生率影响因素可视化分析
数据分享|Python、Spark SQL、MapReduce决策树、回归对车祸发生率影响因素可视化分析
|
分布式计算 算法 测试技术
|
分布式计算 算法 数据挖掘
Spark 数据挖掘 - 利用决策树预测森林覆盖类型
Spark 数据挖掘 利用决策树预测森林覆盖类型
2195 0
|
分布式计算 安全 Spark
【Spark Summit East 2017】RISE实验室: 赋能智能实时决策
本讲义出自Ion Stoica在Spark Summit East 2017上的演讲,主要分享了其所在的加州大学伯克利分校的RISELab的研究方向,并讨论了一些RISE技术能够输出的应用方向。
2118 0
|
3月前
|
人工智能 分布式计算 大数据
大数据≠大样本:基于Spark的特征降维实战(提升10倍训练效率)
本文探讨了大数据场景下降维的核心问题与解决方案,重点分析了“维度灾难”对模型性能的影响及特征冗余的陷阱。通过数学证明与实际案例,揭示高维空间中样本稀疏性问题,并提出基于Spark的分布式降维技术选型与优化策略。文章详细展示了PCA在亿级用户画像中的应用,包括数据准备、核心实现与效果评估,同时深入探讨了协方差矩阵计算与特征值分解的并行优化方法。此外,还介绍了动态维度调整、非线性特征处理及降维与其他AI技术的协同效应,为生产环境提供了最佳实践指南。最终总结出降维的本质与工程实践原则,展望未来发展方向。
176 0
|
6月前
|
存储 分布式计算 Hadoop
从“笨重大象”到“敏捷火花”:Hadoop与Spark的大数据技术进化之路
从“笨重大象”到“敏捷火花”:Hadoop与Spark的大数据技术进化之路
279 79
|
10月前
|
分布式计算 大数据 Apache
ClickHouse与大数据生态集成:Spark & Flink 实战
【10月更文挑战第26天】在当今这个数据爆炸的时代,能够高效地处理和分析海量数据成为了企业和组织提升竞争力的关键。作为一款高性能的列式数据库系统,ClickHouse 在大数据分析领域展现出了卓越的能力。然而,为了充分利用ClickHouse的优势,将其与现有的大数据处理框架(如Apache Spark和Apache Flink)进行集成变得尤为重要。本文将从我个人的角度出发,探讨如何通过这些技术的结合,实现对大规模数据的实时处理和分析。
656 2
ClickHouse与大数据生态集成:Spark & Flink 实战
|
11月前
|
存储 分布式计算 算法
大数据-106 Spark Graph X 计算学习 案例:1图的基本计算、2连通图算法、3寻找相同的用户
大数据-106 Spark Graph X 计算学习 案例:1图的基本计算、2连通图算法、3寻找相同的用户
182 0

热门文章

最新文章