import java.util.HashMap;
import java.util.Map;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;

public class DecisionTreeRegression {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample");
        sparkConf.setMaster("local[2]");
        JavaSparkContext jsc = new JavaSparkContext(sparkConf);

        // Load and parse the data file.
        String datapath = "file:///home/gyq/下载/spark-2.3.2-bin-hadoop2.7/data/mllib/sample_libsvm_data.txt";
        JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();

        // Split the data into training and test sets (30% held out for testing).
        JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
        JavaRDD<LabeledPoint> trainingData = splits[0];
        JavaRDD<LabeledPoint> testData = splits[1];

        // Set parameters.
        // An empty categoricalFeaturesInfo indicates all features are continuous.
        Integer numClasses = 2; // number of classes
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
        /* Impurity measure for split quality. Supported criteria are "gini"
         * (Gini impurity, a measure of disorder) and "entropy" (information gain). */
        String impurity = "gini";
        Integer maxDepth = 5;  // maximum tree depth
        Integer maxBins = 32;  // maximum number of bins used when splitting features

        // Train a DecisionTree model for classification.
        final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
                categoricalFeaturesInfo, impurity, maxDepth, maxBins);

        // Evaluate the model on test instances and compute the test error.
        JavaPairRDD<Double, Double> predictionAndLabel =
                testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
                    public Tuple2<Double, Double> call(LabeledPoint p) {
                        return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
                    }
                });
        Double testErr = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
            public Boolean call(Tuple2<Double, Double> pl) {
                return !pl._1().equals(pl._2());
            }
        }).count() / testData.count();

        System.out.println("Test Error: " + testErr);
        System.out.println("Learned classification tree model:\n" + model.toDebugString());
    }
}
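If you also want to persist the trained tree, the trained model can be saved to disk and reloaded, and the context should be stopped when the job is done. Below is a minimal sketch of such a tail for main(), assuming the `model` and `jsc` variables from the code above; the "target/tmp/myDecisionTreeClassificationModel" path is only an illustrative location, not something the example above requires.

        // Save the trained model and load it back (the path is an arbitrary example location).
        model.save(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel");
        DecisionTreeModel sameModel =
                DecisionTreeModel.load(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel");

        // Shut down the Spark context once all work is finished.
        jsc.stop();

These lines would go at the end of main(), after the two println calls.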