决策树代码共享

简介: 决策树代码共享
import java.util.HashMap;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.mllib.util.MLUtils;
import scala.Tuple2;
public class decisiontree1 {
  public static void main(String[] args) {
  SparkConf sparkConf = new SparkConf () . setAppName (" JavaRandomForestExample"); 
  sparkConf . setMaster("local[2]");
  JavaSparkContext sc = new JavaSparkContext (sparkConf); // Load and parse the data file.
  String datapath =
      "file:///home/gyq/下载/spark-2.3.2-bin-hadoop2.7/data/mllib/gg.txt";
  JavaRDD<String> data=sc.textFile(datapath);
  JavaRDD<LabeledPoint> parseData= data.map(f->{
    String[] parts=f.split(",");
    double[] v=new double[parts.length-1];
    for(int i=0;i<parts.length-1;i++) {
      v[i]=Double.valueOf(parts[i]);
    }
    double label=0.0;
    if (parts[parts.length-1].equals("Iris-versicolor")) {
      label=0.0;}
      else if(parts[parts.length-1].equals("Iris-setosa")){
      label=1.0;}
    else label=2.0;
    return new LabeledPoint(label,Vectors.dense(v));
  });
  parseData.foreach(f->System.out.println("label="+f.label()));
    JavaRDD<LabeledPoint>[] splits = parseData. randomSplit(new double[]{0.7, 0.3});
  JavaRDD<LabeledPoint> trainingData = splits[0];
  JavaRDD<LabeledPoint> testData = splits[1] ;
  // Train a RandomForest model.
  // Empty categoricalFeaturesInfo indicates all features are continuous .
  Integer numClasses = 3;
  HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
  Integer numTrees =3; // Use more in practice.
  String featureSubsetStrategy = "auto"; // Let the algorithm choose.
  String impurity = "gini";
  Integer maxDepth = 5;
  Integer maxBins = 10;
  Integer seed = 12345;
  final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses,
      categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
      seed);
      // Evaluate model on test instances and compute test error
      JavaPairRDD<Double,Double> predictionAndLabel =
      testData. mapToPair(p->{
      return new Tuple2<Double, Double> (model. predict(p. features()), p.label());
      });
      Double testErr =1.0 * predictionAndLabel. filter(pl->{
      return !pl._1(). equals(pl._2());
      }).count()/testData.count();
      System. out. println("Test Error: "+ testErr);
      System. out. println("Learned classification forest model:\n" + model. toDebugString());
  }
}


数据类似这种

4.1.png

相关文章
|
3月前
|
数据采集
PCA与主成分回归(PCR)有何区别?
PCA是降维工具,转化相关变量为线性无关的主成分,保留数据变异。PCR是回归分析方法,利用PCA的主成分预测因变量,应对自变量间的多重共线性,提升模型稳定性。两者协同工作,优化高维数据的建模。
175 0
|
13天前
|
机器学习/深度学习 算法
【机器学习】不同决策树的节点分裂准则(属性划分标准)
决策树的不同节点分裂准则,包括原始决策树的节点分裂准则、ID3算法的信息增益、C4.5算法的信息增益比以及CART算法的平方根误差最小化和基尼指数。
19 1
|
13天前
|
机器学习/深度学习 算法 数据挖掘
【数据挖掘】 GBDT面试题:其中基分类器CART回归树,节点的分裂标准是什么?与RF的区别?与XGB的区别?
文章讨论了梯度提升决策树(GBDT)中的基分类器CART回归树的节点分裂标准,并比较了GBDT与随机森林(RF)和XGBoost(XGB)的区别,包括集成学习方式、偏差-方差权衡、样本使用、并行性、最终结果融合、数据敏感性以及泛化能力等方面的不同。
24 1
|
3月前
|
机器学习/深度学习 算法 数据挖掘
一文介绍回归和分类的本质区别 !!
一文介绍回归和分类的本质区别 !!
113 0
|
3月前
|
机器学习/深度学习 存储 编解码
RepQ带来重参结构新突破 | RepVGG结构真的没办法进行QAT训练吗?
RepQ带来重参结构新突破 | RepVGG结构真的没办法进行QAT训练吗?
87 1
RepQ带来重参结构新突破 | RepVGG结构真的没办法进行QAT训练吗?
|
3月前
|
机器学习/深度学习 JavaScript 前端开发
机器学习 - [源码实现决策树小专题]决策树中子数据集的划分(不允许调用sklearn等库的源代码实现)
机器学习 - [源码实现决策树小专题]决策树中子数据集的划分(不允许调用sklearn等库的源代码实现)
50 0
|
3月前
|
机器学习/深度学习 运维 算法
决策树算法的用途
决策树算法的用途
|
8月前
|
机器学习/深度学习 自然语言处理
神经网络的权值共享有哪些方式
神经网络的权值共享有哪些方式
121 0
|
机器学习/深度学习 PyTorch 测试技术
使用PyTorch构建神经网络(详细步骤讲解+注释版) 01-建立分类器类
神经网络中,一个非常经典的案例就是手写数据的识别,本文我们以手写数据识别为例进行讲解。用到的数据是MNIST数据集。MNIST数据集是一个常用的用于计算机视觉的测试数据集,包含了70,000张手写数字的图片,用于训练和测试模型识别手写数字的能力。MNIST数据集中的图片大小都是28x28像素,图片中的数字是黑白的,每张图片都有对应的标签,表示图片中的数字是什么。MNIST数据集是计算机视觉领域的“Hello World”级别的数据集,被广泛用于计算机视觉模型的训练和测试。
数据集划分方式(误差的评估方法)
留出法(hold out)、交叉验证法(cross validation)、留一法、自助法:(可重复采样,有放回的采样操作)
150 0
数据集划分方式(误差的评估方法)