import java.util.HashMap; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.tree.RandomForest; import org.apache.spark.mllib.tree.model.RandomForestModel; import org.apache.spark.mllib.util.MLUtils; import scala.Tuple2; public class decisiontree1 { public static void main(String[] args) { SparkConf sparkConf = new SparkConf () . setAppName (" JavaRandomForestExample"); sparkConf . setMaster("local[2]"); JavaSparkContext sc = new JavaSparkContext (sparkConf); // Load and parse the data file. String datapath = "file:///home/gyq/下载/spark-2.3.2-bin-hadoop2.7/data/mllib/gg.txt"; JavaRDD<String> data=sc.textFile(datapath); JavaRDD<LabeledPoint> parseData= data.map(f->{ String[] parts=f.split(","); double[] v=new double[parts.length-1]; for(int i=0;i<parts.length-1;i++) { v[i]=Double.valueOf(parts[i]); } double label=0.0; if (parts[parts.length-1].equals("Iris-versicolor")) { label=0.0;} else if(parts[parts.length-1].equals("Iris-setosa")){ label=1.0;} else label=2.0; return new LabeledPoint(label,Vectors.dense(v)); }); parseData.foreach(f->System.out.println("label="+f.label())); JavaRDD<LabeledPoint>[] splits = parseData. randomSplit(new double[]{0.7, 0.3}); JavaRDD<LabeledPoint> trainingData = splits[0]; JavaRDD<LabeledPoint> testData = splits[1] ; // Train a RandomForest model. // Empty categoricalFeaturesInfo indicates all features are continuous . Integer numClasses = 3; HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>(); Integer numTrees =3; // Use more in practice. String featureSubsetStrategy = "auto"; // Let the algorithm choose. 
String impurity = "gini"; Integer maxDepth = 5; Integer maxBins = 10; Integer seed = 12345; final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed); // Evaluate model on test instances and compute test error JavaPairRDD<Double,Double> predictionAndLabel = testData. mapToPair(p->{ return new Tuple2<Double, Double> (model. predict(p. features()), p.label()); }); Double testErr =1.0 * predictionAndLabel. filter(pl->{ return !pl._1(). equals(pl._2()); }).count()/testData.count(); System. out. println("Test Error: "+ testErr); System. out. println("Learned classification forest model:\n" + model. toDebugString()); } }
// The input data looks something like this: