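The dataset comes from the UCI Bank Marketing dataset (UCI Machine Learning Repository: Bank Marketing Data Set). The data relate to direct marketing campaigns run by a Portuguese banking institution. The campaigns were based on phone calls, and more than one contact with the same client was often needed to find out whether the client would subscribe to the product (a bank term deposit).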
The dataset consists of four CSV files:
A random 10% sample of 4,119 examples drawn from 3).
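The 10% figure is simple arithmetic: 10% of the 41,188 rows in bank-additional-full.csv rounds to 4,119. A minimal plain-Scala sketch of drawing such a sample, assuming simple random sampling without replacement (the exact procedure UCI used is not stated here; the object and method names are illustrative):

```scala
import scala.util.Random

object TenPercentSample {
  // Simple random sample without replacement: shuffle all row indices with a fixed seed
  // and keep the first round(n * fraction) of them, returned in ascending order.
  def sampleIndices(n: Int, fraction: Double, seed: Long): Seq[Int] = {
    val k = math.round(n * fraction).toInt
    new Random(seed).shuffle((0 until n).toVector).take(k).sorted
  }

  def main(args: Array[String]): Unit = {
    val idx = sampleIndices(41188, 0.1, seed = 42L)
    println(s"sampled ${idx.size} of 41188 rows") // sampled 4119 of 41188 rows
  }
}
```

This returns an exact count. Spark's `Dataset.sample(false, 0.1, seed)` instead keeps each row independently with probability 0.1, so its result size is only approximately 10%.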
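This experiment uses bank-additional-full.csv; its attributes are described below:
The goal of the experiment is to build a decision tree model and a random forest model that predict whether a client subscribes to the product (a bank term deposit).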
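Both models grow trees by choosing splits that reduce node impurity, and the decision-tree grid search in the code below tunes between Spark's `gini` and `entropy` impurity measures. A minimal plain-Scala sketch of the two formulas for a node's class counts (base-2 logarithm for entropy here; the `Impurity` object is illustrative, not part of Spark's API):

```scala
object Impurity {
  // Class proportions p_i = count_i / total, skipping empty classes
  private def proportions(counts: Seq[Long]): Seq[Double] = {
    val total = counts.sum.toDouble
    counts.filter(_ > 0).map(_ / total)
  }

  // Gini impurity: 1 - sum_i p_i^2
  def gini(counts: Seq[Long]): Double =
    1.0 - proportions(counts).map(p => p * p).sum

  // Entropy: -sum_i p_i * log2(p_i)
  def entropy(counts: Seq[Long]): Double =
    -proportions(counts).map(p => p * math.log(p) / math.log(2)).sum

  def main(args: Array[String]): Unit = {
    // A pure node has zero impurity; a 50/50 binary node is maximally impure
    println(gini(Seq(10L, 0L)))   // 0.0
    println(gini(Seq(5L, 5L)))    // 0.5
    println(entropy(Seq(5L, 5L))) // approximately 1.0 (one bit)
  }
}
```

Both measures rank candidate splits similarly in most cases, which is why treating the choice as a tunable hyperparameter is a reasonable approach.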
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier, RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, OneHotEncoder, StandardScaler, StringIndexer, VectorAssembler}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types._
object BankClassification {
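// Assumed behavior: replace the "unknown" placeholder in a categorical column with the column's mode.
// Not called in main (step 2.1 is skipped).
def unknownValueHandle(col: String, df: DataFrame): DataFrame ={
// Compute the mode: count each non-null, non-"unknown" value and keep the most frequent one
val mostFrequentValue = df.select(col).rdd
.map(row => (row.getString(0), 1))
.filter(_._1 != null).filter(_._1 != "unknown")
.reduceByKey(_ + _)
.sortBy(_._2, ascending = false)
.first()._1
df.na.replace(col, Map("unknown" -> mostFrequentValue))
}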
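// Z-score standardize one numeric column (zero mean, unit variance).
// StandardScaler only accepts Vector input, so the column is first wrapped in a one-element vector.
def standardScale(col: String, df: DataFrame): DataFrame ={
val assembled = new VectorAssembler().setInputCols(Array(col)).setOutputCol(col + "Vec").transform(df)
val stdScaler = new StandardScaler().
setInputCol(col + "Vec").
setOutputCol(col + "Std").
setWithMean(true). // zero mean
setWithStd(true) // unit variance
stdScaler.fit(assembled).transform(assembled)
}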
def main(args: Array[String]): Unit ={
val spark = SparkSession.builder().master("local").appName("BankDecisionTree").getOrCreate()
spark.sparkContext.setLogLevel("ERROR")
// 1. Load the dataset
val rootPath = System.getProperty("user.dir")
val bankPath = "file://"+rootPath+"/src/main/scala/bank-additional-full.csv"
val schema = StructType(
List(
StructField("age", IntegerType, nullable = true),
StructField("job", StringType, nullable = true),
StructField("marital", StringType, nullable = true),
StructField("education", StringType, nullable = true),
StructField("default", StringType, nullable = true),
StructField("housing", StringType, nullable = true),
StructField("loan", StringType, nullable = true),
StructField("contact", StringType, nullable = true),
StructField("month", StringType, nullable = true),
StructField("day_of_week", StringType, nullable = true),
StructField("duration", DoubleType, nullable = true),
StructField("campaign", IntegerType, nullable = true),
StructField("pdays", IntegerType, nullable = true),
StructField("previous", IntegerType, nullable = true),
StructField("poutcome", StringType, nullable = true),
StructField("emp_var_rate", DoubleType, nullable = true),
StructField("cons_price_idx", DoubleType, nullable = true),
StructField("cons_conf_idx", DoubleType, nullable = true),
StructField("euribor3m", DoubleType, nullable = true),
StructField("nr_employed", DoubleType, nullable = true),
StructField("y", StringType, nullable = true),
)
)
val bank = spark.read.option("header",value = true).option("sep",";").schema(schema).csv(bankPath)
bank.printSchema()
val numeric_cols = Array("age","duration","campaign",
"pdays","previous","emp_var_rate","cons_price_idx","cons_conf_idx", "euribor3m", "nr_employed")
val category_cols = Array(
"job","marital","education","default","housing","loan","contact","month","day_of_week","poutcome"
)
// 1. EDA
println("------- First 10 rows of the dataset -------")
// show() prints the table itself and returns Unit, so it is not wrapped in println
bank.show(10)
println("------- Descriptive statistics -------")
bank.describe(numeric_cols: _*).show()
// 2. Data preprocessing and feature engineering
// 2.1 Missing values (coded as "unknown" in this dataset): not handled here
// 2.2 Dummy (one-hot) encoding of the categorical columns
// 2.2.1 First map each category string to a numeric index
val stringIndexer = new StringIndexer()
.setInputCols(category_cols)
.setOutputCols(category_cols.map(_ + "Index"))
.fit(bank)
val indexedBank = stringIndexer.transform(bank)
// 2.2.2 Then expand each index into a dummy vector; setDropLast(false) keeps one slot per category
val oneHotEncoder = new OneHotEncoder()
.setInputCols(category_cols.map(_ + "Index"))
.setOutputCols(category_cols.map(_ + "Vector"))
.setDropLast(false)
val oneHotEncodedBank = oneHotEncoder.fit(indexedBank).transform(indexedBank)
// 2.3 Z-score standardization of the numeric columns (disabled)
// var standardBank = oneHotEncodedBank
// for (col <- numeric_cols) {
// standardBank = standardScale(col, standardBank)
// }
// 2.4 Encode the string label y as a numeric index yIndex
// (by default StringIndexer orders labels by descending frequency, so the majority class gets 0.0)
val labelIndexer = new StringIndexer().setInputCol("y").setOutputCol("yIndex")
val labeledBank = labelIndexer.fit(oneHotEncodedBank).transform(oneHotEncodedBank)
// 3. Assemble the dummy vectors and the raw numeric columns into one feature vector
val featureCols = category_cols.map(_ + "Vector") ++ numeric_cols
val vectorAssembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
val df = vectorAssembler.transform(labeledBank)
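// 4. Train/test split (70/30, seed 42): trainp/testp hold raw rows for the Pipeline,
// train/test hold already-featurized rows for cross-validation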
val Array(trainp,testp) = bank.randomSplit(Array(0.7,0.3),42)
val Array(train,test) = df.randomSplit(Array(0.7,0.3),42)
// 5. Model training
// 5.1 Random forest (default parameters, trained as a full Pipeline)
for(i<-1 to 5){
println("###############################")
}
val rfc = new RandomForestClassifier().setLabelCol("yIndex").setFeaturesCol("features")
val rfcPipeline = new Pipeline().setStages(Array(stringIndexer, oneHotEncoder, labelIndexer, vectorAssembler, rfc))
val rfcPipelineModel = rfcPipeline.fit(trainp)
val rfcModel = rfcPipelineModel.stages(4).asInstanceOf[RandomForestClassificationModel]
println("Trained random forest model:\n" + rfcModel.toDebugString)
val testPreds = rfcPipelineModel.transform(testp)
println("Random forest (default parameters) predictions =>")
testPreds.select("y","yIndex","prediction").show(30)
println("Random forest (default parameters) evaluation =>")
// BinaryClassificationEvaluator defaults to areaUnderROC, so this score is AUC, not accuracy
val evaluator = new BinaryClassificationEvaluator().setLabelCol("yIndex")
val auc = evaluator.evaluate(testPreds)
println(s"AUC (areaUnderROC): ${auc}")
// Alternative without a Pipeline, training directly on the pre-featurized split:
// val rfcModel = rfc.fit(train)
// println("Trained random forest model:\n" + rfcModel.toDebugString)
// val testPredictions = rfcModel.transform(test)
// println("Random forest (default parameters) predictions =>")
// testPredictions.select("y","yIndex","prediction").show(30)
// val evaluator = new BinaryClassificationEvaluator().setLabelCol("yIndex")
// println("Random forest (default parameters) evaluation =>")
// val auc = evaluator.evaluate(testPredictions)
// println(s"AUC (areaUnderROC): ${auc}")
// 5.2 Decision tree (grid search + 10-fold cross-validation)
for(i<-1 to 5){
println("###############################")
}
val dtc = new DecisionTreeClassifier().setLabelCol("yIndex").setFeaturesCol("features")
val paramGrid = new ParamGridBuilder().addGrid(
dtc.maxDepth, Array(4,6,8,10,12)
).addGrid(
dtc.impurity, Array("entropy","gini")
).build()
val binaryClassificationEvaluator = new BinaryClassificationEvaluator().setLabelCol("yIndex")
val cv = new CrossValidator()
.setEstimator(dtc)
.setEstimatorParamMaps(paramGrid)
.setEvaluator(binaryClassificationEvaluator)
.setNumFolds(10)
val cvModel = cv.fit(train)
val testPredictions_ = cvModel.transform(test)
println("Decision tree predictions =>")
testPredictions_.select("y","yIndex","prediction").show(30)
println("Decision tree evaluation =>")
println("AUC (areaUnderROC): " + binaryClassificationEvaluator.evaluate(testPredictions_))
val dtcBest = cvModel.bestModel.asInstanceOf[DecisionTreeClassificationModel]
println(dtcBest.explainParam(dtcBest.maxDepth))
println(dtcBest.explainParam(dtcBest.impurity))
}
}
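Because `BinaryClassificationEvaluator` defaults to the `areaUnderROC` metric, the scores printed above are AUC values rather than accuracy; accuracy in Spark would come from `MulticlassClassificationEvaluator` with `metricName` set to `"accuracy"`. The plain-Scala sketch below, using hypothetical scores and labels, shows how far the two can diverge. AUC is computed here by its pairwise (Mann-Whitney) definition for illustration, which is not necessarily how Spark computes it internally.

```scala
object AucVsAccuracy {
  // AUC = probability that a random positive is scored above a random negative
  // (ties count 1/2). O(P * N), fine for small illustrative inputs.
  def auc(scores: Seq[Double], labels: Seq[Int]): Double = {
    val pairs = scores.zip(labels)
    val pos = pairs.collect { case (s, 1) => s }
    val neg = pairs.collect { case (s, 0) => s }
    val wins = for (p <- pos; n <- neg) yield {
      if (p > n) 1.0 else if (p == n) 0.5 else 0.0
    }
    wins.sum / (pos.size.toDouble * neg.size)
  }

  // Accuracy after thresholding the scores into hard 0/1 predictions
  def accuracy(scores: Seq[Double], labels: Seq[Int], threshold: Double = 0.5): Double = {
    val correct = scores.zip(labels).count { case (s, y) => (if (s >= threshold) 1 else 0) == y }
    correct.toDouble / labels.size
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical scores: every positive outranks every negative, but all fall below 0.5
    val scores = Seq(0.1, 0.2, 0.3, 0.4)
    val labels = Seq(0, 0, 1, 1)
    println(s"AUC = ${auc(scores, labels)}")           // AUC = 1.0
    println(s"accuracy = ${accuracy(scores, labels)}") // accuracy = 0.5
  }
}
```

AUC measures only how well the scores rank positives above negatives, independent of any threshold, while accuracy depends on where the threshold falls.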