Spark ML(lib)实验:利用银行营销数据集预测客户是否订阅产品

一、实验描述

数据集来源于UCI的银行营销数据集(UCI Machine Learning Repository: Bank Marketing Data Set)。数据与葡萄牙一家银行机构的直接营销活动有关。营销活动以电话联系为基础,通常需要与同一客户进行一次以上的联系,才能确认客户是否会订阅产品(银行定期存款)。
该数据集一共包含四个csv文件:

  1. bank-additional-full.csv
    包含所有的样例(41188个)和所有的特征输入(20个),根据时间排序(从2008年5月到2010年9月)。
  2. bank-additional.csv
    从1)中随机选出10%的样例4119个。
  3. bank-full.csv
    包含所有的样例(45211个)和17个特征输入,根据时间排序(该数据集是更老的版本,特征输入较少)。
  4. bank.csv
    从3)中随机选出10%的样例4521个。

实验选取bank-additional-full.csv 文件,属性详情如下:

实验目的是通过构建决策树模型、随机森林模型,预测客户是否订阅产品(银行定期存款)。

二、实验代码

Spark ML(lib)实验:利用银行营销数据集预测客户是否订阅产品_第1张图片

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier, RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, OneHotEncoder, StandardScaler, StringIndexer, VectorAssembler}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{StructField, _}

/**
 * Spark ML experiment on the UCI Bank Marketing data set (`bank-additional-full.csv`).
 *
 * The program loads the CSV file, prints a short exploratory summary, and then trains
 * two classifiers that predict the `y` column (whether the client subscribed):
 * a random forest with default parameters and a decision tree tuned by grid search
 * with 10-fold cross-validation.
 *
 * All feature transformers are fitted on the training split only, so no information
 * from the test split reaches the fitted preprocessing stages.
 */
object BankClassification {
    /** String-typed input columns; each one is indexed and then one-hot encoded. */
    private val CategoryCols: Array[String] = Array(
        "job", "marital", "education", "default", "housing",
        "loan", "contact", "month", "day_of_week", "poutcome"
    )

    /** Numeric input columns; they are passed to the models unchanged. */
    private val NumericCols: Array[String] = Array(
        "age", "duration", "campaign", "pdays", "previous",
        "emp_var_rate", "cons_price_idx", "cons_conf_idx", "euribor3m", "nr_employed"
    )

    /** Name of the numeric label column produced from the string column `y`. */
    private val LabelCol = "yIndex"

    /** Seed used for the train/test split so that runs are repeatable. */
    private val SplitSeed = 42L

    /** Explicit schema of the CSV file, in file column order. */
    private val BankSchema: StructType = StructType(
        List(
            StructField("age", IntegerType, nullable = true),
            StructField("job", StringType, nullable = true),
            StructField("marital", StringType, nullable = true),
            StructField("education", StringType, nullable = true),
            StructField("default", StringType, nullable = true),
            StructField("housing", StringType, nullable = true),
            StructField("loan", StringType, nullable = true),
            StructField("contact", StringType, nullable = true),
            StructField("month", StringType, nullable = true),
            StructField("day_of_week", StringType, nullable = true),
            StructField("duration", DoubleType, nullable = true),
            StructField("campaign", IntegerType, nullable = true),
            StructField("pdays", IntegerType, nullable = true),
            StructField("previous", IntegerType, nullable = true),
            StructField("poutcome", StringType, nullable = true),
            StructField("emp_var_rate", DoubleType, nullable = true),
            StructField("cons_price_idx", DoubleType, nullable = true),
            StructField("cons_conf_idx", DoubleType, nullable = true),
            StructField("euribor3m", DoubleType, nullable = true),
            StructField("nr_employed", DoubleType, nullable = true),
            StructField("y", StringType, nullable = true)
        )
    )

    /**
     * Returns the most frequent non-null value (the mode) of a string column.
     *
     * @param col name of a string-typed column of `df`
     * @param df  data frame to inspect
     * @return the mode, or `None` when `df` is empty or the mode is `null`
     */
    private def mostFrequentValue(col: String, df: DataFrame): Option[String] =
        df.select(col).rdd
            .map(row => (row.getString(0), 1L))
            // reduceByKey combines counts map-side instead of shuffling every row like groupByKey.
            .reduceByKey(_ + _)
            .sortBy(_._2, ascending = false)
            .take(1)
            .headOption
            .flatMap { case (value, _) => Option(value) }

    /**
     * Computes the mode of `col`.
     *
     * The result is currently discarded: this method returns `Unit` and does not modify `df`.
     * NOTE(review): presumably meant to replace "unknown" placeholders with the mode; that
     * replacement step is not implemented here.
     *
     * @param col name of a string-typed column of `df`
     * @param df  data frame to inspect
     */
    def unknownValueHandle(col: String, df: DataFrame): Unit = {
        mostFrequentValue(col, df)
        ()
    }

    /**
     * Standardizes one column to zero mean and unit standard deviation.
     *
     * NOTE: Spark's `StandardScaler` operates on Vector columns, so `col` should name a
     * vector column (for example an assembled feature vector), not a plain numeric column.
     *
     * @param col name of the column to scale
     * @param df  data frame containing `col`; the scaler is fitted on this same data
     * @return `df` with an extra column named `col + "Std"` holding the scaled values
     */
    def standardScale(col: String, df: DataFrame): DataFrame = {
        val scaler = new StandardScaler()
            .setInputCol(col)
            .setOutputCol(col + "Std")
            .setWithMean(true) // centre to zero mean
            .setWithStd(true)  // scale to unit standard deviation
        scaler.fit(df).transform(df)
    }

    /** Reads the bank marketing CSV from `src/main/scala` under the working directory. */
    private def loadBank(spark: SparkSession): DataFrame = {
        val rootPath = System.getProperty("user.dir")
        val bankPath = s"file://$rootPath/src/main/scala/bank-additional-full.csv"
        spark.read.option("header", value = true).option("sep", ";").schema(BankSchema).csv(bankPath)
    }

    /**
     * Builds the (unfitted) preprocessing pipeline: string indexing and one-hot encoding of
     * the categorical columns, label indexing of `y`, and assembly of the `features` vector.
     */
    private def preprocessingPipeline(): Pipeline = {
        val indexCols = CategoryCols.map(_ + "Index")
        val vectorCols = CategoryCols.map(_ + "Vector")

        // "keep" maps categories unseen during fitting to an extra index, so a rare category
        // that appears only in the test split does not make transform() fail.
        val featureIndexer = new StringIndexer()
            .setInputCols(CategoryCols)
            .setOutputCols(indexCols)
            .setHandleInvalid("keep")
        val oneHotEncoder = new OneHotEncoder()
            .setInputCols(indexCols)
            .setOutputCols(vectorCols)
            .setDropLast(false)
        val labelIndexer = new StringIndexer().setInputCol("y").setOutputCol(LabelCol)
        // Numeric columns are not standardized: tree-based models are insensitive to feature scale.
        val vectorAssembler = new VectorAssembler()
            .setInputCols(vectorCols ++ NumericCols)
            .setOutputCol("features")

        new Pipeline().setStages(Array(featureIndexer, oneHotEncoder, labelIndexer, vectorAssembler))
    }

    /** Prints the five-line separator used between experiment sections. */
    private def printBanner(): Unit =
        (1 to 5).foreach(_ => println("###############################"))

    /** Fraction of rows whose `prediction` equals the label; `NaN` for an empty data frame. */
    private def accuracyOf(predictions: DataFrame): Double = {
        val total = predictions.count()
        if (total == 0L) Double.NaN
        else {
            val correct = predictions.filter(predictions(LabelCol) === predictions("prediction")).count()
            correct.toDouble / total
        }
    }

    /** Prints the area under the ROC curve and the plain accuracy of `predictions`. */
    private def printScores(predictions: DataFrame, evaluator: BinaryClassificationEvaluator): Unit = {
        println(s"AUC:${evaluator.evaluate(predictions)}")
        println(s"Accuracy:${accuracyOf(predictions)}")
    }

    /**
     * Entry point: runs the whole experiment on a local Spark session.
     *
     * @param args command-line arguments (unused)
     */
    def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local").appName("BankDecisionTree").getOrCreate()
        try {
            spark.sparkContext.setLogLevel("ERROR")

            // 1. Load the data set
            val bank = loadBank(spark)
            bank.printSchema()

            // 2. Exploratory data analysis
            println("-------数据集前10行-------")
            bank.show(10) // show() prints by itself and returns Unit
            println("------描述性统计分析------")
            bank.describe(NumericCols: _*).show()

            // 3. Split first, so that preprocessing is fitted on training rows only
            val Array(trainRaw, testRaw) = bank.randomSplit(Array(0.7, 0.3), SplitSeed)

            // 4. Preprocessing and feature engineering (missing-value handling is skipped)
            val preprocessingModel = preprocessingPipeline().fit(trainRaw)
            // The training set is reused by every cross-validation fit, so keep it cached.
            val train = preprocessingModel.transform(trainRaw).cache()
            val test = preprocessingModel.transform(testRaw)

            // The metric is set explicitly: this evaluator measures ranking quality (AUC), not accuracy.
            val evaluator = new BinaryClassificationEvaluator()
                .setLabelCol(LabelCol)
                .setMetricName("areaUnderROC")

            // 5. Model training
            // 5.1 Random forest with default parameters
            printBanner()
            val rfcModel = new RandomForestClassifier()
                .setLabelCol(LabelCol)
                .setFeaturesCol("features")
                .fit(train)
            println("训练到的随机森林模型:\n" + rfcModel.toDebugString)
            val rfcPredictions = rfcModel.transform(test)
            println("随机森林(默认参数)预测结果=>")
            rfcPredictions.select("y", "yIndex", "prediction").show(30)
            println("随机森林(默认参数)预测评分=>")
            printScores(rfcPredictions, evaluator)

            // 5.2 Decision tree tuned by grid search with cross-validation
            printBanner()
            val dtc = new DecisionTreeClassifier().setLabelCol(LabelCol).setFeaturesCol("features")
            val paramGrid = new ParamGridBuilder()
                .addGrid(dtc.maxDepth, Array(4, 6, 8, 10, 12))
                .addGrid(dtc.impurity, Array("entropy", "gini"))
                .build()
            val cvModel = new CrossValidator()
                .setEstimator(dtc)
                .setEstimatorParamMaps(paramGrid)
                .setEvaluator(evaluator)
                .setNumFolds(10)
                .fit(train)
            val dtcPredictions = cvModel.transform(test)
            println("决策树预测结果=>")
            dtcPredictions.select("y", "yIndex", "prediction").show(30)
            println("决策树预测评分=>")
            printScores(dtcPredictions, evaluator)

            // Report the hyper-parameters selected by cross-validation.
            cvModel.bestModel match {
                case best: DecisionTreeClassificationModel =>
                    println(best.explainParam(best.maxDepth))
                    println(best.explainParam(best.impurity))
                case other =>
                    println(s"Unexpected best model type: ${other.getClass.getName}")
            }
        } finally {
            // Always release the Spark session, even when a step above fails.
            spark.stop()
        }
    }
}

Spark ML(lib)实验:利用银行营销数据集预测客户是否订阅产品_第2张图片
Spark ML(lib)实验:利用银行营销数据集预测客户是否订阅产品_第3张图片
Spark ML(lib)实验:利用银行营销数据集预测客户是否订阅产品_第4张图片
Spark ML(lib)实验:利用银行营销数据集预测客户是否订阅产品_第5张图片

你可能感兴趣的:(大数据,Application,spark)