spark 决策树分类 DecisionTreeClassifier

为什么80%的码农都做不了架构师?>>>   hot3.png

决策树分类是一个非概率模型。测试数据集使用网上公开的泰坦尼克号乘客数据,通过决策树(DecisionTreeClassifier)算法,根据 Pclass、Sex、Age 三个特征来预测乘客是否获救。
pom.xml


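<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
  xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>com.penngo.spark.ml</groupId>
  <artifactId>sparkml</artifactId>
  <packaging>jar</packaging>
  <version>1.0-SNAPSHOT</version>
  <name>sparkml</name>
  <url>http://maven.apache.org</url>
  <properties>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
    <java.version>1.8</java.version>
  </properties>
  <dependencies>
    <dependency>
      <groupId>junit</groupId>
      <artifactId>junit</artifactId>
      <version>3.8.1</version>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_2.11</artifactId>
      <version>2.2.3</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_2.11</artifactId>
      <version>2.2.3</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-mllib_2.11</artifactId>
      <version>2.2.3</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-streaming_2.11</artifactId>
      <version>2.2.3</version>
    </dependency>
  </dependencies>
  <build>
    <plugins>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-compiler-plugin</artifactId>
        <version>3.7.0</version>
        <configuration>
          <source>1.8</source>
          <target>1.8</target>
          <encoding>UTF-8</encoding>
        </configuration>
      </plugin>
    </plugins>
  </build>
</project>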

DecisionTreeClassification.java

package com.penngo.spark.ml.main;

import org.apache.log4j.Logger;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.*;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import java.io.File;
import org.apache.spark.sql.functions;
import static org.apache.spark.sql.types.DataTypes.DoubleType;

/**
 * Spark decision-tree classification example ({@link DecisionTreeClassifier}).
 *
 * <p>Reads the Titanic passenger data from {@code data/titanic.txt} under the working
 * directory, derives numeric features from passenger class, age and sex, trains a decision
 * tree to predict the {@code survived} column, and prints the test error and the learned tree.
 */
public class DecisionTreeClassification {
    private static final Logger log = Logger.getLogger(DecisionTreeClassification.class);

    /** Hadoop home used for local debugging on Windows unless {@code -Dhadoop.home.dir} is given. */
    private static final String DEFAULT_WINDOWS_HADOOP_HOME = "D:/hadoop/hadoop-2.7.6";

    /** Shared session: created by {@link #initSpark()}, stopped and cleared at the end of {@link #run()}. */
    private static SparkSession spark = null;

    /**
     * Creates the shared {@link SparkSession} if one does not exist yet.
     *
     * <p>On non-Windows hosts (e.g. a Linux cluster) no master is set here, so it comes from the
     * launcher configuration; on Windows a {@code local[3]} session is created for local debugging.
     */
    public static void initSpark() {
        if (spark == null) {
            String os = System.getProperty("os.name").toLowerCase();
            if (!os.contains("windows")) {
                // Running on Linux / a cluster.
                spark = SparkSession
                        .builder()
                        .appName("DecisionTreeClassification")
                        .getOrCreate();
            } else {
                // Running on Windows for local debugging.
                // Only fall back to the hard-coded Hadoop home when none was supplied explicitly.
                if (System.getProperty("hadoop.home.dir") == null) {
                    System.setProperty("hadoop.home.dir", DEFAULT_WINDOWS_HADOOP_HOME);
                }
                System.setProperty("HADOOP_USER_NAME", "hadoop");
                spark = SparkSession
                        .builder()
                        .appName("DecisionTreeClassification").master("local[3]")
                        .getOrCreate();
            }
        }
        log.warn("spark.conf().getAll()=============" + spark.conf().getAll());
    }

    /**
     * Loads the data, trains the decision-tree pipeline on a 70% split, evaluates it on the
     * remaining 30% and prints the results. The shared session is stopped afterwards, even if a
     * step fails.
     *
     * <p>{@link #initSpark()} must have been called first.
     */
    public static void run() {
        try {
            String dataPath = new File("").getAbsolutePath() + "/data/titanic.txt";
            Dataset<Row> data = spark.read().option("header", "true").csv(dataPath);
            data.show();

            Dataset<Row> data1 = toNumericColumns(data);

            // Pack the numeric inputs into the single vector column the ML estimators expect.
            VectorAssembler assembler = new VectorAssembler()
                    .setInputCols(new String[]{"pclass1", "age1", "sex"})
                    .setOutputCol("features");
            Dataset<Row> data2 = assembler.transform(data1);
            data2.show();

            // Index labels, adding metadata to the label column.
            // Fitted on the whole dataset so every label value is known to the indexer.
            StringIndexerModel labelIndexer = new StringIndexer()
                    .setInputCol("label")
                    .setOutputCol("indexedLabel")
                    .fit(data2);
            // Automatically identify categorical features and index them. Features with more than
            // maxCategories distinct values are treated as continuous; with setMaxCategories left
            // commented out, Spark's default (20) applies.
            VectorIndexerModel featureIndexer = new VectorIndexer()
                    .setInputCol("features")
                    .setOutputCol("indexedFeatures")
                    //.setMaxCategories(3)
                    .fit(data2);

            // Split the data into training and test sets (30% held out for testing).
            // No seed is passed, so the split - and the reported error - differs between runs.
            Dataset<Row>[] splits = data2.randomSplit(new double[]{0.7, 0.3});
            Dataset<Row> trainingData = splits[0];
            Dataset<Row> testData = splits[1];

            // Decision-tree estimator; optional tuning knobs are listed below.
            DecisionTreeClassifier dt = new DecisionTreeClassifier()
                    .setLabelCol("indexedLabel")
                    .setFeaturesCol("indexedFeatures");
            //.setImpurity("entropy") // "gini" impurity or "entropy"
            //.setMaxBins(100) // max number of bins used to discretize continuous features
            //.setMaxDepth(5) // maximum tree depth
            //.setMinInfoGain(0.01) // minimum information gain for a node split, in [0,1]
            //.setMinInstancesPerNode(10) // minimum number of instances per node
            //.setSeed(123456)

            // Map indexed predictions back to the original label values.
            IndexToString labelConverter = new IndexToString()
                    .setInputCol("prediction")
                    .setOutputCol("predictedLabel")
                    .setLabels(labelIndexer.labels());

            // Chain the indexers, the tree and the label converter in a Pipeline.
            Pipeline pipeline = new Pipeline()
                    .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter});

            // Train the model.
            PipelineModel model = pipeline.fit(trainingData);

            // Predict on the held-out data.
            Dataset<Row> predictions = model.transform(testData);

            predictions.select("user_id", "features", "label", "prediction").show();
            //predictions.select("predictedLabel", "label", "features").show(5);

            // Test error = 1 - accuracy; label and prediction are both in the indexed label space.
            MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                    .setLabelCol("indexedLabel")
                    .setPredictionCol("prediction")
                    .setMetricName("accuracy");
            double accuracy = evaluator.evaluate(predictions);
            System.out.println("Test Error = " + (1.0 - accuracy));

            // Inspect the learned tree: stage 2 of the pipeline is the fitted DecisionTreeClassifier.
            DecisionTreeClassificationModel treeModel =
                    (DecisionTreeClassificationModel) (model.stages()[2]);
            System.out.println("Learned classification tree model:\n" + treeModel.toDebugString());
        } finally {
            // Always release the session; clearing the field lets initSpark() create a fresh one.
            if (spark != null) {
                spark.stop();
                spark = null;
            }
        }
    }

    /**
     * Converts the raw string columns into the numeric inputs used by the model.
     *
     * <p>Output columns: {@code user_id}, {@code label} (survived as double), {@code pclass1}
     * (1st/2nd/3rd to 1.0/2.0/3.0), {@code age1} (age as double with missing values imputed)
     * and {@code sex} (0 for female, 1 otherwise).
     *
     * @param data raw CSV rows read with a header
     * @return the selected, numeric columns
     * @throws IllegalStateException if no age value is numeric, so no mean can be imputed
     */
    private static Dataset<Row> toNumericColumns(Dataset<Row> data) {
        // mean() on the string "age" column casts it to double; non-numeric entries such as
        // "NA" become null and are skipped by the aggregate.
        Double mage = data.select(functions.mean("age").as("mage")).first().getAs("mage");
        if (mage == null) {
            throw new IllegalStateException(
                    "Column 'age' contains no numeric values; cannot impute missing ages");
        }
        // Missing ages are imputed with the mean truncated to a whole number of years.
        double imputedAge = mage.intValue();

        return data.select(
                functions.col("user_id"),
                functions.col("survived").cast(DoubleType).as("label"),
                // NOTE(review): a pclass outside 1st/2nd/3rd yields null here, which
                // VectorAssembler in Spark 2.2 rejects.
                functions.when(functions.col("pclass").equalTo("1st"), 1)
                        .when(functions.col("pclass").equalTo("2nd"), 2)
                        .when(functions.col("pclass").equalTo("3rd"), 3)
                        .cast(DoubleType).as("pclass1"),
                // Any age that does not parse as a number ("NA", empty, null) falls back to the
                // imputed mean, so no null reaches VectorAssembler.
                functions.coalesce(functions.col("age").cast(DoubleType), functions.lit(imputedAge))
                        .as("age1"),
                functions.when(functions.col("sex").equalTo("female"), 0).otherwise(1).as("sex")
        );
    }

    public static void main(String[] args) {
        initSpark();
        run();
    }
}

基础数据

spark 决策树分类 DecisionTreeClassifier_第1张图片

过滤、特征化后的数据

spark 决策树分类 DecisionTreeClassifier_第2张图片
预测结果

spark 决策树分类 DecisionTreeClassifier_第3张图片
预测错误率和预测模型

spark 决策树分类 DecisionTreeClassifier_第4张图片

转载于:https://my.oschina.net/penngo/blog/3018547

你可能感兴趣的:(大数据,python,数据结构与算法)