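A Simple Data-Mining Application with Spark Machine Learning (Hobby Prediction)

The data-mining process

A data-mining task consists of six main steps:

  • 1. Data preprocessing
  • 2. Feature transformation
  • 3. Feature selection
  • 4. Model training
  • 5. Model prediction
  • 6. Evaluating the prediction results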

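Data preparation

Here is a small dataset of 20 people from different regions, of different sexes, heights, weights and so on, together with their hobbies (saved as hobby.csv):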

id,hobby,sex,address,age,height,weight
1,football,male,dalian,12,168,55
2,pingpang,female,yangzhou,21,163,60
3,football,male,dalian,,172,70
4,football,female,,13,167,58
5,pingpang,female,shanghai,63,170,64
6,football,male,dalian,30,177,76
7,basketball,male,shanghai,25,181,90
8,football,male,dalian,15,172,71
9,basketball,male,shanghai,25,179,80
10,pingpang,male,shanghai,55,175,72
11,football,male,dalian,13,169,55
12,pingpang,female,yangzhou,22,164,61
13,football,male,dalian,23,170,71
14,football,female,,12,164,55
15,pingpang,female,shanghai,64,169,63
16,football,male,dalian,30,177,76
17,basketball,male,shanghai,22,180,80
18,football,male,dalian,16,173,72
19,basketball,male,shanghai,23,176,73
20,pingpang,male,shanghai,56,171,71
  • Task analysis
    Predict a person's hobby from five features: sex, address, age, height and weight.

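Data preprocessing

Before reading any data, we need to create a Spark object.

Defining the Spark object

Call SparkSession.builder() to get a builder, set appName and master on it, and finally call getOrCreate() to create the session.

    // define the Spark object
    val spark = SparkSession.builder().appName("HobbyPrediction").master("local[*]").getOrCreate()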

Reading the data

Read the data with spark.read: set the format to "CSV", treat the first line as the header, and finally give the file path:

    val df = spark.read.format("CSV").option("header", true).load("C:/Users/35369/Desktop/hobby.csv")

Inspect the data with df.show() and df.printSchema():

    df.show()
    df.printSchema()

    spark.stop()  // stop Spark

Output:

+---+----------+------+--------+----+------+------+
| id|     hobby|   sex| address| age|height|weight|
+---+----------+------+--------+----+------+------+
|  1|  football|  male|  dalian|  12|   168|    55|
|  2|  pingpang|female|yangzhou|  21|   163|    60|
|  3|  football|  male|  dalian|null|   172|    70|
|  4|  football|female|    null|  13|   167|    58|
|  5|  pingpang|female|shanghai|  63|   170|    64|
|  6|  football|  male|  dalian|  30|   177|    76|
|  7|basketball|  male|shanghai|  25|   181|    90|
|  8|  football|  male|  dalian|  15|   172|    71|
|  9|basketball|  male|shanghai|  25|   179|    80|
| 10|  pingpang|  male|shanghai|  55|   175|    72|
| 11|  football|  male|  dalian|  13|   169|    55|
| 12|  pingpang|female|yangzhou|  22|   164|    61|
| 13|  football|  male|  dalian|  23|   170|    71|
| 14|  football|female|    null|  12|   164|    55|
| 15|  pingpang|female|shanghai|  64|   169|    63|
| 16|  football|  male|  dalian|  30|   177|    76|
| 17|basketball|  male|shanghai|  22|   180|    80|
| 18|  football|  male|  dalian|  16|   173|    72|
| 19|basketball|  male|shanghai|  23|   176|    73|
| 20|  pingpang|  male|shanghai|  56|   171|    71|
+---+----------+------+--------+----+------+------+

root
 |-- id: string (nullable = true)
 |-- hobby: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- address: string (nullable = true)
 |-- age: string (nullable = true)
 |-- height: string (nullable = true)
 |-- weight: string (nullable = true)
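No schema was given, so the printSchema() output above shows every column as string. That is why a cast step appears later in preprocessing. If you would rather declare the types up front, one option is to pass a StructType to the reader. The sketch below does that. The object name HobbySchema and the inline sample lines are illustrative only. With the real file you would call .csv(path) instead of passing a Dataset[String].

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types._

// Hypothetical helper: declares the column types of hobby.csv instead of reading everything as string
object HobbySchema {
  val schema: StructType = StructType(Seq(
    StructField("id", IntegerType),
    StructField("hobby", StringType),
    StructField("sex", StringType),
    StructField("address", StringType),
    StructField("age", DoubleType),
    StructField("height", DoubleType),
    StructField("weight", DoubleType)
  ))

  // Parses CSV text (header line included) with the declared schema; empty fields become null
  def read(spark: SparkSession, lines: Seq[String]): DataFrame = {
    import spark.implicits._
    spark.read.option("header", true).schema(schema).csv(lines.toDS())
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("schema-sketch").master("local[*]").getOrCreate()
    val sample = Seq(
      "id,hobby,sex,address,age,height,weight",
      "1,football,male,dalian,12,168,55",
      "3,football,male,dalian,,172,70",
      "4,football,female,,13,167,58"
    )
    val df = read(spark, sample)
    df.printSchema() // age, height and weight are now double
    df.show()        // the empty age and address fields appear as null
    spark.stop()
  }
}
```

With the types declared this way, the later select(col(...).cast(...)) step becomes unnecessary. The mean used to fill age then has to be a Double rather than a String.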

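Filling in missing ages

Filling a numeric column takes three steps:
(1) take the column with its null rows removed
(2) compute the mean of that column
(3) fill the mean back into the original table

  • (1) Take the column with its null rows removed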
    val ageNaDF = df.select("age").na.drop()
    ageNaDF.show()
+---+
|age|
+---+
| 12|
| 21|
| 13|
| 63|
| 30|
| 25|
| 15|
| 25|
| 55|
| 13|
| 22|
| 23|
| 12|
| 64|
| 30|
| 22|
| 16|
| 23|
| 56|
+---+
  • (2) Compute the mean of the column from (1)

Look at the summary statistics of ageNaDF:

    ageNaDF.describe("age").show()

Output:

+-------+-----------------+
|summary|              age|
+-------+-----------------+
|  count|               19|
|   mean|28.42105263157895|
| stddev|17.48432882286206|
|    min|               12|
|    max|               64|
+-------+-----------------+

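The mean is 28.42105263157895, and we need to extract it. describe() returns its rows in the order count, mean, stddev, min, max, and every value is a string. So row 1, column 0 of the collected result is the mean:

    val mean = ageNaDF.describe("age").select("age").collect()(1)(0).toString
    print(mean) //28.42105263157895
  • (3) Fill the mean into the original table
    df.na.fill() fills null values. To restrict it to the age column, pass List("age") as the second argument. Because age is still a string column at this point, the mean is passed as a string.
    val ageFilledDF = df.na.fill(mean,List("age"))
    ageFilledDF.show()

Output: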

+---+----------+------+--------+-----------------+------+------+
| id|     hobby|   sex| address|              age|height|weight|
+---+----------+------+--------+-----------------+------+------+
|  1|  football|  male|  dalian|               12|   168|    55|
|  2|  pingpang|female|yangzhou|               21|   163|    60|
|  3|  football|  male|  dalian|28.42105263157895|   172|    70|
|  4|  football|female|    null|               13|   167|    58|
|  5|  pingpang|female|shanghai|               63|   170|    64|
|  6|  football|  male|  dalian|               30|   177|    76|
|  7|basketball|  male|shanghai|               25|   181|    90|
|  8|  football|  male|  dalian|               15|   172|    71|
|  9|basketball|  male|shanghai|               25|   179|    80|
| 10|  pingpang|  male|shanghai|               55|   175|    72|
| 11|  football|  male|  dalian|               13|   169|    55|
| 12|  pingpang|female|yangzhou|               22|   164|    61|
| 13|  football|  male|  dalian|               23|   170|    71|
| 14|  football|female|    null|               12|   164|    55|
| 15|  pingpang|female|shanghai|               64|   169|    63|
| 16|  football|  male|  dalian|               30|   177|    76|
| 17|basketball|  male|shanghai|               22|   180|    80|
| 18|  football|  male|  dalian|               16|   173|    72|
| 19|basketball|  male|shanghai|               23|   176|    73|
| 20|  pingpang|  male|shanghai|               56|   171|    71|
+---+----------+------+--------+-----------------+------+------+

The missing age (id 3) has been filled with the mean.
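Indexing into describe() by position works, but it depends on the row order and returns a string. A possible alternative is to compute the mean directly with the avg aggregate, which skips nulls, and then fill with the Double overload of na.fill once age is numeric. The sketch below assumes this approach. MeanAge is a hypothetical name, and the three sample rows are made up for illustration.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{avg, col}

object MeanAge {
  // Mean of the non-null ages; avg ignores nulls, so no separate na.drop() is needed
  def meanAge(df: DataFrame): Double =
    df.select(avg(col("age").cast("double"))).first().getDouble(0)

  // Cast age to double, then replace its nulls with the mean
  def fillAge(df: DataFrame): DataFrame = {
    val numeric = df.withColumn("age", col("age").cast("double"))
    numeric.na.fill(meanAge(numeric), Seq("age"))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("mean-sketch").master("local[*]").getOrCreate()
    import spark.implicits._
    // age kept as a string column, like the CSV read without a schema; None becomes null
    val df = Seq[(String, Option[String])](("1", Some("12")), ("2", Some("21")), ("3", None)).toDF("id", "age")
    println(meanAge(df)) // (12 + 21) / 2 = 16.5
    fillAge(df).show()
    spark.stop()
  }
}
```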

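Dropping rows with a missing address

There is no reasonable value to fill into the address column, so rows with a missing address are dropped instead, using .na.drop().
Called without arguments, na.drop() removes every row that has a null in any column. Passing Seq("address") limits it to the address column, which states the intent more clearly. The result is the same here because age has already been filled.

    val addressDf = ageFilledDF.na.drop(Seq("address"))
    addressDf.show()

Output: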

+---+----------+------+--------+-----------------+------+------+
| id|     hobby|   sex| address|              age|height|weight|
+---+----------+------+--------+-----------------+------+------+
|  1|  football|  male|  dalian|               12|   168|    55|
|  2|  pingpang|female|yangzhou|               21|   163|    60|
|  3|  football|  male|  dalian|28.42105263157895|   172|    70|
|  5|  pingpang|female|shanghai|               63|   170|    64|
|  6|  football|  male|  dalian|               30|   177|    76|
|  7|basketball|  male|shanghai|               25|   181|    90|
|  8|  football|  male|  dalian|               15|   172|    71|
|  9|basketball|  male|shanghai|               25|   179|    80|
| 10|  pingpang|  male|shanghai|               55|   175|    72|
| 11|  football|  male|  dalian|               13|   169|    55|
| 12|  pingpang|female|yangzhou|               22|   164|    61|
| 13|  football|  male|  dalian|               23|   170|    71|
| 15|  pingpang|female|shanghai|               64|   169|    63|
| 16|  football|  male|  dalian|               30|   177|    76|
| 17|basketball|  male|shanghai|               22|   180|    80|
| 18|  football|  male|  dalian|               16|   173|    72|
| 19|basketball|  male|shanghai|               23|   176|    73|
| 20|  pingpang|  male|shanghai|               56|   171|    71|
+---+----------+------+--------+-----------------+------+------+

The rows with id 4 and 14 have been removed.

Casting each column to a suitable type

    // adjust the schema of df
    val formatDF = addressDf.select(
      col("id").cast("int"),
      col("hobby").cast("String"),
      col("sex").cast("String"),
      col("address").cast("String"),
      col("age").cast("Double"),
      col("height").cast("Double"),
      col("weight").cast("Double")
    )
    formatDF.printSchema()

Output:

root
 |-- id: integer (nullable = true)
 |-- hobby: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- address: string (nullable = true)
 |-- age: double (nullable = true)
 |-- height: double (nullable = true)
 |-- weight: double (nullable = true)

Data preprocessing is now complete.

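Feature transformation

To make the features easier for the model to use, this step discretizes or encodes age, height, weight, address and sex:
- age is bucketized;
- height and weight are binarized;
- sex and address are converted to numeric indices.

Bucketizing age

  • under 18
  • 18 to 35
  • 35 to 60
  • 60 and over

The Bucketizer class does the bucketing:
1. Set the input and output column names.
2. Pass the array of split points as the bucket boundaries.
3. Call transform on the DataFrame to be bucketized.

Each bucket includes its lower bound and excludes its upper bound. For example, an age of exactly 18 falls into bucket 1 (18 to 35).

    // 2.1 bucketize age
    // array of split points defining the buckets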
    val ageSplits = Array(Double.NegativeInfinity,18,35,60,Double.PositiveInfinity)
    val bucketizerDF = new Bucketizer()
      .setInputCol("age")
      .setOutputCol("ageFeature")
      .setSplits(ageSplits)
      .transform(formatDF)
    bucketizerDF.show()

Bucketing result:

+---+----------+------+--------+-----------------+------+------+----------+
| id|     hobby|   sex| address|              age|height|weight|ageFeature|
+---+----------+------+--------+-----------------+------+------+----------+
|  1|  football|  male|  dalian|             12.0| 168.0|  55.0|       0.0|
|  2|  pingpang|female|yangzhou|             21.0| 163.0|  60.0|       1.0|
|  3|  football|  male|  dalian|28.42105263157895| 172.0|  70.0|       1.0|
|  5|  pingpang|female|shanghai|             63.0| 170.0|  64.0|       3.0|
|  6|  football|  male|  dalian|             30.0| 177.0|  76.0|       1.0|
|  7|basketball|  male|shanghai|             25.0| 181.0|  90.0|       1.0|
|  8|  football|  male|  dalian|             15.0| 172.0|  71.0|       0.0|
|  9|basketball|  male|shanghai|             25.0| 179.0|  80.0|       1.0|
| 10|  pingpang|  male|shanghai|             55.0| 175.0|  72.0|       2.0|
| 11|  football|  male|  dalian|             13.0| 169.0|  55.0|       0.0|
| 12|  pingpang|female|yangzhou|             22.0| 164.0|  61.0|       1.0|
| 13|  football|  male|  dalian|             23.0| 170.0|  71.0|       1.0|
| 15|  pingpang|female|shanghai|             64.0| 169.0|  63.0|       3.0|
| 16|  football|  male|  dalian|             30.0| 177.0|  76.0|       1.0|
| 17|basketball|  male|shanghai|             22.0| 180.0|  80.0|       1.0|
| 18|  football|  male|  dalian|             16.0| 173.0|  72.0|       0.0|
| 19|basketball|  male|shanghai|             23.0| 176.0|  73.0|       1.0|
| 20|  pingpang|  male|shanghai|             56.0| 171.0|  71.0|       2.0|
+---+----------+------+--------+-----------------+------+------+----------+
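The ageFeature column above follows Bucketizer's rule. Each bucket is left-closed and right-open, [18, 35) for instance, and the last bucket also includes its upper bound. The plain-Scala function below only models that rule for illustration; it is not Spark's implementation, and AgeBuckets is a hypothetical name. It is handy for checking boundary cases such as an age of exactly 18.

```scala
// Plain-Scala model of the bucketing rule (illustration, not Spark's implementation):
// bucket i covers [splits(i), splits(i + 1)); the last bucket also includes its upper bound
object AgeBuckets {
  val splits: Array[Double] = Array(Double.NegativeInfinity, 18.0, 35.0, 60.0, Double.PositiveInfinity)

  def bucket(x: Double): Int =
    if (x == splits.last) splits.length - 2
    else {
      // first split point strictly greater than x; the bucket is the one just before it
      val i = splits.indexWhere(s => x < s)
      require(i > 0, s"$x is outside the split range")
      i - 1
    }

  def main(args: Array[String]): Unit =
    Seq(12.0, 18.0, 28.42105263157895, 55.0, 63.0).foreach(a => println(s"$a -> ${bucket(a)}"))
}
```

The values 12, 28.42…, 55 and 63 land in buckets 0, 1, 2 and 3, matching the table above.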

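Binarizing height

The threshold is 170, using the Binarizer class. Values strictly greater than the threshold become 1.0 and all others become 0.0.

    //2.2 binarize height
    val heightDF = new Binarizer()
      .setInputCol("height")
      .setOutputCol("heightFeature")
      .setThreshold(170) // threshold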
      .transform(bucketizerDF)
    heightDF.show()

Result:

+---+----------+------+--------+-----------------+------+------+----------+-------------+
| id|     hobby|   sex| address|              age|height|weight|ageFeature|heightFeature|
+---+----------+------+--------+-----------------+------+------+----------+-------------+
|  1|  football|  male|  dalian|             12.0| 168.0|  55.0|       0.0|          0.0|
|  2|  pingpang|female|yangzhou|             21.0| 163.0|  60.0|       1.0|          0.0|
|  3|  football|  male|  dalian|28.42105263157895| 172.0|  70.0|       1.0|          1.0|
|  5|  pingpang|female|shanghai|             63.0| 170.0|  64.0|       3.0|          0.0|
|  6|  football|  male|  dalian|             30.0| 177.0|  76.0|       1.0|          1.0|
|  7|basketball|  male|shanghai|             25.0| 181.0|  90.0|       1.0|          1.0|
|  8|  football|  male|  dalian|             15.0| 172.0|  71.0|       0.0|          1.0|
|  9|basketball|  male|shanghai|             25.0| 179.0|  80.0|       1.0|          1.0|
| 10|  pingpang|  male|shanghai|             55.0| 175.0|  72.0|       2.0|          1.0|
| 11|  football|  male|  dalian|             13.0| 169.0|  55.0|       0.0|          0.0|
| 12|  pingpang|female|yangzhou|             22.0| 164.0|  61.0|       1.0|          0.0|
| 13|  football|  male|  dalian|             23.0| 170.0|  71.0|       1.0|          0.0|
| 15|  pingpang|female|shanghai|             64.0| 169.0|  63.0|       3.0|          0.0|
| 16|  football|  male|  dalian|             30.0| 177.0|  76.0|       1.0|          1.0|
| 17|basketball|  male|shanghai|             22.0| 180.0|  80.0|       1.0|          1.0|
| 18|  football|  male|  dalian|             16.0| 173.0|  72.0|       0.0|          1.0|
| 19|basketball|  male|shanghai|             23.0| 176.0|  73.0|       1.0|          1.0|
| 20|  pingpang|  male|shanghai|             56.0| 171.0|  71.0|       2.0|          1.0|
+---+----------+------+--------+-----------------+------+------+----------+-------------+
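In the table above, id 5 and id 13 (height exactly 170) get heightFeature 0.0. Binarizer only outputs 1.0 for values strictly greater than the threshold, and the same rule applies to weight with threshold 65 in the next step. Below is a tiny plain-Scala model of that rule, written for illustration (BinarizeRule is a hypothetical name):

```scala
// Plain-Scala model of Binarizer's rule (illustration, not Spark's code):
// values strictly above the threshold map to 1.0, everything else to 0.0
object BinarizeRule {
  def binarize(x: Double, threshold: Double): Double = if (x > threshold) 1.0 else 0.0

  def main(args: Array[String]): Unit = {
    val heights = Seq(168.0, 170.0, 172.0)
    println(heights.map(binarize(_, 170.0))) // List(0.0, 0.0, 1.0)
  }
}
```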

Binarizing weight

The threshold is set to 65.

    //2.3 binarize weight
    val weightDF = new Binarizer()
      .setInputCol("weight")
      .setOutputCol("weightFeature")
      .setThreshold(65)
      .transform(heightDF)

    weightDF.show()

Handling the sex, address and hobby fields

These three fields are strings, and string values cannot be fed directly into Spark ML algorithms. They also need to be converted (encoded) into numbers.

    //2.4 label-encode sex
    val sexIndex = new StringIndexer()
      .setInputCol("sex")
      .setOutputCol("sexIndex")
      .fit(weightDF)
      .transform(weightDF)

    //2.5 label-encode address
    val addIndex = new StringIndexer()
      .setInputCol("address")
      .setOutputCol("addIndex")
      .fit(sexIndex)
      .transform(sexIndex)

    //2.6 one-hot encode address
    val addOneHot = new OneHotEncoder()
      .setInputCol("addIndex")
      .setOutputCol("addOneHot")
      .fit(addIndex)
      .transform(addIndex)

    //2.7 label-encode hobby
    val hobbyIndexDF = new StringIndexer()
      .setInputCol("hobby")
      .setOutputCol("hobbyIndex")
      .fit(addOneHot)
      .transform(addOneHot)

    hobbyIndexDF.show()

Address additionally gets a one-hot encoding here, on top of its index.
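By default StringIndexer numbers values by descending frequency (stringOrderType = "frequencyDesc"), so the most common value gets 0.0. As far as I know, Spark 3.x breaks frequency ties alphabetically. That would explain why dalian (8 rows) gets 0.0 ahead of shanghai (also 8 rows) in the final feature table. The plain-Scala sketch below reproduces that ordering for illustration only; IndexOrder is a hypothetical name and this is not Spark's implementation.

```scala
// Plain-Scala model of StringIndexer's default ordering (illustration only):
// sort by descending count, break ties alphabetically, then number from 0
object IndexOrder {
  def indexOf(values: Seq[String]): Map[String, Double] =
    values
      .groupBy(identity)
      .map { case (v, occurrences) => (v, occurrences.size) }
      .toSeq
      .sortBy { case (v, n) => (-n, v) }
      .zipWithIndex
      .map { case ((v, _), i) => v -> i.toDouble }
      .toMap

  def main(args: Array[String]): Unit = {
    // address column of the 18 rows left after dropping ids 4 and 14
    val address = Seq.fill(8)("dalian") ++ Seq.fill(8)("shanghai") ++ Seq.fill(2)("yangzhou")
    println(indexOf(address))
  }
}
```

The same rule applied to hobby (8 football, 6 pingpang, 4 basketball) gives football 0.0, pingpang 1.0 and basketball 2.0. These are the label values in the final table.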

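Rename the hobbyIndex column to label, because hobby is the label (the target) during model training.

    //2.8 rename the column
    val resultDF = hobbyIndexDF.withColumnRenamed("hobbyIndex","label")
    resultDF.show()

Final result of the feature transformation: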

+---+----------+------+--------+-----------------+------+------+----------+-------------+-------------+--------+--------+-------------+-----+
| id|     hobby|   sex| address|              age|height|weight|ageFeature|heightFeature|weightFeature|sexIndex|addIndex|    addOneHot|label|
+---+----------+------+--------+-----------------+------+------+----------+-------------+-------------+--------+--------+-------------+-----+
|  1|  football|  male|  dalian|             12.0| 168.0|  55.0|       0.0|          0.0|          0.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
|  2|  pingpang|female|yangzhou|             21.0| 163.0|  60.0|       1.0|          0.0|          0.0|     1.0|     2.0|    (2,[],[])|  1.0|
|  3|  football|  male|  dalian|28.42105263157895| 172.0|  70.0|       1.0|          1.0|          1.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
|  5|  pingpang|female|shanghai|             63.0| 170.0|  64.0|       3.0|          0.0|          0.0|     1.0|     1.0|(2,[1],[1.0])|  1.0|
|  6|  football|  male|  dalian|             30.0| 177.0|  76.0|       1.0|          1.0|          1.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
|  7|basketball|  male|shanghai|             25.0| 181.0|  90.0|       1.0|          1.0|          1.0|     0.0|     1.0|(2,[1],[1.0])|  2.0|
|  8|  football|  male|  dalian|             15.0| 172.0|  71.0|       0.0|          1.0|          1.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
|  9|basketball|  male|shanghai|             25.0| 179.0|  80.0|       1.0|          1.0|          1.0|     0.0|     1.0|(2,[1],[1.0])|  2.0|
| 10|  pingpang|  male|shanghai|             55.0| 175.0|  72.0|       2.0|          1.0|          1.0|     0.0|     1.0|(2,[1],[1.0])|  1.0|
| 11|  football|  male|  dalian|             13.0| 169.0|  55.0|       0.0|          0.0|          0.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
| 12|  pingpang|female|yangzhou|             22.0| 164.0|  61.0|       1.0|          0.0|          0.0|     1.0|     2.0|    (2,[],[])|  1.0|
| 13|  football|  male|  dalian|             23.0| 170.0|  71.0|       1.0|          0.0|          1.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
| 15|  pingpang|female|shanghai|             64.0| 169.0|  63.0|       3.0|          0.0|          0.0|     1.0|     1.0|(2,[1],[1.0])|  1.0|
| 16|  football|  male|  dalian|             30.0| 177.0|  76.0|       1.0|          1.0|          1.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
| 17|basketball|  male|shanghai|             22.0| 180.0|  80.0|       1.0|          1.0|          1.0|     0.0|     1.0|(2,[1],[1.0])|  2.0|
| 18|  football|  male|  dalian|             16.0| 173.0|  72.0|       0.0|          1.0|          1.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
| 19|basketball|  male|shanghai|             23.0| 176.0|  73.0|       1.0|          1.0|          1.0|     0.0|     1.0|(2,[1],[1.0])|  2.0|
| 20|  pingpang|  male|shanghai|             56.0| 171.0|  71.0|       2.0|          1.0|          1.0|     0.0|     1.0|(2,[1],[1.0])|  1.0|
+---+----------+------+--------+-----------------+------+------+----------+-------------+-------------+--------+--------+-------------+-----+
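addOneHot has only 2 slots for 3 cities because OneHotEncoder drops the last category by default (dropLast = true). As a result, yangzhou (addIndex 2.0) becomes the all-zero vector (2,[],[]), which is Spark's sparse-vector notation (size, [indices], [values]). Calling setDropLast(false) on the encoder would keep a third slot. Below is a plain-Scala model of the encoding, for illustration only (OneHotSketch is a hypothetical name):

```scala
// Plain-Scala model of one-hot encoding with dropLast (illustration, not Spark's code)
object OneHotSketch {
  // dense vector of length numCategories, or numCategories - 1 when the last category is dropped
  def encode(index: Int, numCategories: Int, dropLast: Boolean = true): Vector[Double] = {
    val size = if (dropLast) numCategories - 1 else numCategories
    Vector.tabulate(size)(i => if (i == index) 1.0 else 0.0)
  }

  def main(args: Array[String]): Unit =
    for ((city, idx) <- Seq("dalian" -> 0, "shanghai" -> 1, "yangzhou" -> 2))
      println(s"$city -> ${encode(idx, numCategories = 3)}")
}
```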

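Feature selection

The feature transformation produces many columns, but not all of them can be used to train a model. Feature selection picks the columns that can actually be used for machine learning.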

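Assembling the features

VectorAssembler combines the chosen columns into a single vector column, features. The label column must not be one of the inputs, otherwise the model is handed the answer it is supposed to predict.

    //3.1 select features
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("ageFeature","heightFeature","weightFeature","sexIndex","addOneHot"))
      .setOutputCol("features")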
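VectorAssembler concatenates its input columns in the order given. Scalar columns contribute one slot each, and vector columns such as addOneHot are expanded in place. With the inputs used in the full source further down (ageFeature, heightFeature, weightFeature, sexIndex, addOneHot), every row therefore becomes a 6-dimensional vector. Below is a plain-Scala illustration for row id 2, using the values from the final feature table. AssembleSketch and assemble are hypothetical names, not Spark's API.

```scala
// Plain-Scala illustration of how VectorAssembler lays out one row (not Spark's code):
// scalar columns give one slot each, the one-hot vector is appended slot by slot
object AssembleSketch {
  def assemble(scalars: Seq[Double], oneHot: Seq[Double]): Vector[Double] =
    (scalars ++ oneHot).toVector

  def main(args: Array[String]): Unit = {
    // row id 2: ageFeature, heightFeature, weightFeature, sexIndex, then addOneHot (2,[],[]) as dense zeros
    val features = assemble(Seq(1.0, 0.0, 0.0, 1.0), Seq(0.0, 0.0))
    println(features.length) // 6
  }
}
```

Knowing the dimension matters later: it caps how many features ChiSqSelector can keep.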

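Standardizing the features

    val scaler = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("featureScaler")
      .setWithStd(true)    // scale each feature to unit standard deviation
      .setWithMean(false)  // do not center with the mean (centering would produce dense vectors)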
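With setWithStd(true) and setWithMean(false), each feature is divided by its standard deviation and not shifted. As far as I know, Spark uses the corrected sample standard deviation (dividing by n - 1), and a feature whose standard deviation is 0 is output as 0.0. Below is a plain-Scala model of that computation for a single feature column, for illustration only (ScaleSketch is a hypothetical name):

```scala
// Plain-Scala model of StandardScaler with withStd = true, withMean = false (illustration only)
object ScaleSketch {
  // corrected sample standard deviation (n - 1 in the denominator)
  def sampleStd(xs: Seq[Double]): Double = {
    require(xs.size > 1, "need at least two values")
    val mean = xs.sum / xs.size
    math.sqrt(xs.map(x => (x - mean) * (x - mean)).sum / (xs.size - 1))
  }

  // divide every value by the std; a constant column maps to 0.0
  def scale(xs: Seq[Double]): Seq[Double] = {
    val std = sampleStd(xs)
    xs.map(x => if (std == 0.0) 0.0 else x / std)
  }

  def main(args: Array[String]): Unit =
    println(scale(Seq(0.0, 2.0))) // values: 0.0 and 2 / sqrt(2)
}
```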

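Filtering features

    // feature filtering with the chi-squared test; read the scaled vector from the scaler,
    // since the default featuresCol ("features") would bypass the StandardScaler step
    val selector = new ChiSqSelector()
      .setFeaturesCol("featureScaler")
      .setLabelCol("label")
      .setOutputCol("featuresSelector")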
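ChiSqSelector scores each feature with Pearson's chi-squared test of independence against the label and keeps the most significant ones. Scaling does not change this: dividing by a standard deviation maps distinct values one to one, so the contingency table stays the same. The sketch below computes only the statistic, for sexIndex against hobby, using counts from the 18 rows in the final feature table (male: 8 football, 2 pingpang, 4 basketball; female: 0, 4, 0). Spark computes it on the training split and also derives degrees of freedom and a p-value. ChiSquareSketch is a hypothetical name.

```scala
// Pearson's chi-squared statistic for a contingency table (rows = feature values, columns = labels)
object ChiSquareSketch {
  def statistic(table: Array[Array[Double]]): Double = {
    val rowSums = table.map(_.sum)
    val colSums = table.transpose.map(_.sum)
    val total = rowSums.sum
    (for {
      i <- table.indices
      j <- table(i).indices
    } yield {
      val expected = rowSums(i) * colSums(j) / total // count expected under independence
      val diff = table(i)(j) - expected
      diff * diff / expected
    }).sum
  }

  def main(args: Array[String]): Unit = {
    // columns: football, pingpang, basketball
    val sexVsHobby = Array(
      Array(8.0, 2.0, 4.0), // male
      Array(0.0, 4.0, 0.0)  // female
    )
    println(statistic(sexVsHobby)) // 72 / 7, about 10.29
  }
}
```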

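Building the logistic regression model and the pipeline

    // logistic regression model
    val lr = new LogisticRegression().setLabelCol("label").setFeaturesCol("featuresSelector")

    // build the pipeline
    val pipeline = new Pipeline().setStages(Array(vectorAssembler,scaler,selector,lr))

Grid search for the best parameters

    // grid search over parameters
    val params = new ParamGridBuilder()
      .addGrid(lr.regParam,Array(0.1,0.01))  // regularization parameter
      .addGrid(selector.numTopFeatures,Array(5,10))  // features kept by the chi-squared selector (the assembled vector has 6 dimensions, so 10 keeps them all)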
      .build()
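ParamGridBuilder builds the Cartesian product of the listed values. CrossValidator fits the whole pipeline once per parameter map per fold, plus one final fit of the best map on all the training data. A repeated value in addGrid therefore only adds redundant fits. The hedged sketch below just counts the maps for a 2 × 2 grid; GridSize is a hypothetical name, and building the grid does not need a SparkSession.

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.ChiSqSelector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.ParamGridBuilder

object GridSize {
  val lr = new LogisticRegression()
  val selector = new ChiSqSelector()

  // 2 regParam values x 2 numTopFeatures values = 4 parameter maps
  val grid: Array[ParamMap] = new ParamGridBuilder()
    .addGrid(lr.regParam, Array(0.1, 0.01))
    .addGrid(selector.numTopFeatures, Array(5, 10))
    .build()

  def main(args: Array[String]): Unit = {
    val numFolds = 5
    println(s"param maps: ${grid.length}, pipeline fits during CV: ${grid.length * numFolds}")
    grid.foreach(println)
  }
}
```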

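Setting up cross-validation

    // set up cross-validation; hobby has three classes (football, pingpang, basketball),
    // so a multiclass evaluator is needed: BinaryClassificationEvaluator is for two-class problems
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new MulticlassClassificationEvaluator())
      .setEstimatorParamMaps(params)
      .setNumFolds(5)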

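Model training and prediction

Before training, split the data into a training set and a test set:

    val Array(trainDF, testDF) = resultDF.randomSplit(Array(0.8, 0.2), 42L)  // fixed seed so the split is reproducible

randomSplit performs the split. With only 18 rows the test set holds just a few rows, and the split changes from run to run unless a seed is given.

  • Training and prediction
    val model = cv.fit(trainDF)

    // model prediction
    val prediction = model.bestModel.transform(testDF)
    prediction.show()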
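The outline at the top lists evaluating the predictions as step 6, which the code never reaches. A minimal sketch follows. It assumes a prediction DataFrame with the default label and prediction columns, like the one returned by model.bestModel.transform(testDF). MulticlassClassificationEvaluator with metricName "accuracy" gives the fraction of correct predictions. The tiny hand-made DataFrame is only for illustration, and EvalSketch is a hypothetical name.

```scala
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.{DataFrame, SparkSession}

object EvalSketch {
  // fraction of rows where prediction equals label
  def accuracy(predictions: DataFrame): Double =
    new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
      .evaluate(predictions)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("eval-sketch").master("local[*]").getOrCreate()
    import spark.implicits._
    // made-up (label, prediction) pairs: 3 of 4 are correct
    val predictions = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 2.0), (1.0, 0.0)).toDF("label", "prediction")
    println(accuracy(predictions)) // 0.75
    spark.stop()
  }
}
```

With a test set of only a few rows, accuracy moves in large steps, so treat a single number with caution.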

Error: help wanted

Calling cv.fit(trainDF) throws the error below, and I couldn't find anything about this message online:

Exception in thread "main" java.lang.NoClassDefFoundError: org/apache/spark/sql/catalyst/trees/BinaryLike
	at java.lang.ClassLoader.defineClass1(Native Method)
	at java.lang.ClassLoader.defineClass(ClassLoader.java:756)
	at java.security.SecureClassLoader.defineClass(SecureClassLoader.java:142)
	at java.net.URLClassLoader.defineClass(URLClassLoader.java:473)
	at java.net.URLClassLoader.access$100(URLClassLoader.java:74)
	at java.net.URLClassLoader$1.run(URLClassLoader.java:369)
	at java.net.URLClassLoader$1.run(URLClassLoader.java:363)
	at java.security.AccessController.doPrivileged(Native Method)
	at java.net.URLClassLoader.findClass(URLClassLoader.java:362)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:418)
	at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:355)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:351)
	at org.apache.spark.ml.stat.SummaryBuilderImpl.summary(Summarizer.scala:251)
	at org.apache.spark.ml.stat.SummaryBuilder.summary(Summarizer.scala:54)
	at org.apache.spark.ml.feature.StandardScaler.fit(StandardScaler.scala:112)
	at org.apache.spark.ml.feature.StandardScaler.fit(StandardScaler.scala:84)
	at org.apache.spark.ml.Pipeline.$anonfun$fit$5(Pipeline.scala:151)
	at org.apache.spark.ml.MLEvents.withFitEvent(events.scala:130)
	at org.apache.spark.ml.MLEvents.withFitEvent$(events.scala:123)
	at org.apache.spark.ml.util.Instrumentation.withFitEvent(Instrumentation.scala:42)
	at org.apache.spark.ml.Pipeline.$anonfun$fit$4(Pipeline.scala:151)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at org.apache.spark.ml.Pipeline.$anonfun$fit$2(Pipeline.scala:147)
	at org.apache.spark.ml.MLEvents.withFitEvent(events.scala:130)
	at org.apache.spark.ml.MLEvents.withFitEvent$(events.scala:123)
	at org.apache.spark.ml.util.Instrumentation.withFitEvent(Instrumentation.scala:42)
	at org.apache.spark.ml.Pipeline.$anonfun$fit$1(Pipeline.scala:133)
	at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
	at org.apache.spark.ml.Pipeline.fit(Pipeline.scala:133)
	at org.apache.spark.ml.Pipeline.fit(Pipeline.scala:93)
	at org.apache.spark.ml.Estimator.fit(Estimator.scala:59)
	at org.apache.spark.ml.tuning.CrossValidator.$anonfun$fit$7(CrossValidator.scala:174)
	at scala.runtime.java8.JFunction0$mcD$sp.apply(JFunction0$mcD$sp.java:23)
	at scala.concurrent.Future$.$anonfun$apply$1(Future.scala:659)
	at scala.util.Success.$anonfun$map$1(Try.scala:255)
	at scala.util.Success.map(Try.scala:213)
	at scala.concurrent.Future.$anonfun$map$1(Future.scala:292)
	at scala.concurrent.impl.Promise.liftedTree1$1(Promise.scala:33)
	at scala.concurrent.impl.Promise.$anonfun$transform$1(Promise.scala:33)
	at scala.concurrent.impl.CallbackRunnable.run(Promise.scala:64)
	at org.sparkproject.guava.util.concurrent.MoreExecutors$SameThreadExecutorService.execute(MoreExecutors.java:293)
	at scala.concurrent.impl.ExecutionContextImpl$$anon$4.execute(ExecutionContextImpl.scala:138)
	at scala.concurrent.impl.CallbackRunnable.executeWithValue(Promise.scala:72)
	at scala.concurrent.impl.Promise$KeptPromise$Kept.onComplete(Promise.scala:372)
	at scala.concurrent.impl.Promise$KeptPromise$Kept.onComplete$(Promise.scala:371)
	at scala.concurrent.impl.Promise$KeptPromise$Successful.onComplete(Promise.scala:379)
	at scala.concurrent.impl.Promise.transform(Promise.scala:33)
	at scala.concurrent.impl.Promise.transform$(Promise.scala:31)
	at scala.concurrent.impl.Promise$KeptPromise$Successful.transform(Promise.scala:379)
	at scala.concurrent.Future.map(Future.scala:292)
	at scala.concurrent.Future.map$(Future.scala:292)
	at scala.concurrent.impl.Promise$KeptPromise$Successful.map(Promise.scala:379)
	at scala.concurrent.Future$.apply(Future.scala:659)
	at org.apache.spark.ml.tuning.CrossValidator.$anonfun$fit$6(CrossValidator.scala:182)
	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
	at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
	at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
	at scala.collection.TraversableLike.map(TraversableLike.scala:286)
	at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
	at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:198)
	at org.apache.spark.ml.tuning.CrossValidator.$anonfun$fit$4(CrossValidator.scala:172)
	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
	at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
	at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
	at scala.collection.TraversableLike.map(TraversableLike.scala:286)
	at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
	at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:198)
	at org.apache.spark.ml.tuning.CrossValidator.$anonfun$fit$1(CrossValidator.scala:166)
	at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
	at org.apache.spark.ml.tuning.CrossValidator.fit(CrossValidator.scala:137)
	at org.example.SparkML.SparkMl01$.main(SparkMl01.scala:147)
	at org.example.SparkML.SparkMl01.main(SparkMl01.scala)
Caused by: java.lang.ClassNotFoundException: org.apache.spark.sql.catalyst.trees.BinaryLike
	at java.net.URLClassLoader.findClass(URLClassLoader.java:387)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:418)
	at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:355)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:351)

Full source code and pom file

package org.example.SparkML


import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{Binarizer, Bucketizer, ChiSqSelector, OneHotEncoder, StandardScaler, StringIndexer, VectorAssembler}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

/**
 * The data-mining process
 * 1. Data preprocessing
 * 2. Feature transformation (encoding, ...)
 * 3. Feature selection
 * 4. Model training
 * 5. Model prediction
 * 6. Evaluating the prediction results
 */
object SparkMl01 {
  def main(args: Array[String]): Unit = {
    // define the Spark object
    val spark = SparkSession.builder().appName("HobbyPrediction").master("local").getOrCreate()
    import spark.implicits._
    val df=spark.read.format("CSV").option("header",true).load("C:/Users/35369/Desktop/hobby.csv")
    // 1. data preprocessing: fill in missing ages with the mean
    val ageNaDF = df.select("age").na.drop()
    val mean = ageNaDF.describe("age").select("age").collect()(1)(0).toString
    val ageFilledDF = df.na.fill(mean,List("age"))
    // drop rows whose address is null
    val addressDf = ageFilledDF.na.drop(Seq("address"))
    // adjust the schema of df
    val formatDF = addressDf.select(
      col("id").cast("int"),
      col("hobby").cast("String"),
      col("sex").cast("String"),
      col("address").cast("String"),
      col("age").cast("Double"),
      col("height").cast("Double"),
      col("weight").cast("Double")
    )

    // 2. feature transformation
    // 2.1 bucketize age
    // array of split points defining the buckets
    val ageSplits = Array(Double.NegativeInfinity,18,35,60,Double.PositiveInfinity)
    val bucketizerDF = new Bucketizer()
      .setInputCol("age")
      .setOutputCol("ageFeature")
      .setSplits(ageSplits)
      .transform(formatDF)


    // 2.2 binarize height
    val heightDF = new Binarizer()
      .setInputCol("height")
      .setOutputCol("heightFeature")
      .setThreshold(170) // threshold
      .transform(bucketizerDF)

    // 2.3 binarize weight
    val weightDF = new Binarizer()
      .setInputCol("weight")
      .setOutputCol("weightFeature")
      .setThreshold(65)
      .transform(heightDF)

    // 2.4 label-encode sex
    val sexIndex = new StringIndexer()
      .setInputCol("sex")
      .setOutputCol("sexIndex")
      .fit(weightDF)
      .transform(weightDF)

    // 2.5 label-encode address
    val addIndex = new StringIndexer()
      .setInputCol("address")
      .setOutputCol("addIndex")
      .fit(sexIndex)
      .transform(sexIndex)

    // 2.6 one-hot encode address
    val addOneHot = new OneHotEncoder()
      .setInputCol("addIndex")
      .setOutputCol("addOneHot")
      .fit(addIndex)
      .transform(addIndex)

    // 2.7 label-encode hobby
    val hobbyIndexDF = new StringIndexer()
      .setInputCol("hobby")
      .setOutputCol("hobbyIndex")
      .fit(addOneHot)
      .transform(addOneHot)

    // 2.8 rename the column
    val resultDF = hobbyIndexDF.withColumnRenamed("hobbyIndex","label")


    // 3. feature selection
    // 3.1 select features (the label column is deliberately not included)
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("ageFeature","heightFeature","weightFeature","sexIndex","addOneHot"))
      .setOutputCol("features")

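    // 3.2 standardize the features
    val scaler = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("featureScaler")
      .setWithStd(true)    // scale each feature to unit standard deviation
      .setWithMean(false)  // do not center with the mean (centering would produce dense vectors)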


    // feature filtering with the chi-squared test
    val selector = new ChiSqSelector()
      .setFeaturesCol("featureScaler")
      .setLabelCol("label")
      .setOutputCol("featuresSelector")


    // logistic regression model
    val lr = new LogisticRegression().setLabelCol("label").setFeaturesCol("featuresSelector")

    // build the pipeline
    val pipeline = new Pipeline()
      .setStages(Array(vectorAssembler,scaler,selector,lr))

    // grid search over parameters
    val params = new ParamGridBuilder()
      .addGrid(lr.regParam,Array(0.1,0.01))  // regularization parameter
      .addGrid(selector.numTopFeatures,Array(5,10))  // number of features kept by the chi-squared selector
      .build()

    // set up cross-validation; hobby has three classes, so use a multiclass evaluator
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new MulticlassClassificationEvaluator())
      .setEstimatorParamMaps(params)
      .setNumFolds(5)

    // model training; a fixed seed keeps the split reproducible
    val Array(trainDF,testDF) = resultDF.randomSplit(Array(0.8,0.2), 42L)
    trainDF.show()
    testDF.show()
    val model = cv.fit(trainDF)

    // build the model directly, without cross-validation
//    val model = pipeline.fit(trainDF)
//    val prediction = model.transform(testDF)
//    prediction.show()



    // model prediction
//    val prediction = model.bestModel.transform(testDF)
//    prediction.show()



    spark.stop()
  }
}

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>untitled</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>2.12.18</version>
        </dependency>


        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.12</artifactId>
            <version>3.0.0-preview2</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-hive_2.12</artifactId>
            <version>3.1.2</version>
<!--            <scope>provided</scope>-->
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.12</artifactId>
            <version>3.0.0-preview2</version>
<!--            <scope>compile</scope>-->
        </dependency>


<!--        <dependency>-->
<!--            <groupId>mysql</groupId>-->
<!--            <artifactId>mysql-connector-java</artifactId>-->
<!--            <version>8.0.16</version>-->
<!--        </dependency>-->

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.12</artifactId>
            <version>3.5.0</version>
<!--            <scope>compile</scope>-->
        </dependency>




    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-shade-plugin</artifactId>
                <version>2.4.1</version>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>shade</goal>
                        </goals>
                        <configuration>
                            <transformers>
                                <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
                                    <mainClass>org.example.SparkML.SparkMl01</mainClass>
                                </transformer>
                            </transformers>
                        </configuration>
                    </execution>
                </executions>
            </plugin>

        </plugins>
    </build>


</project>
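A likely cause, hedged since I have not run this project: the pom mixes Spark versions.
- spark-core and spark-sql are 3.0.0-preview2.
- spark-hive is 3.1.2.
- spark-mllib is 3.5.0.

The stack trace shows spark-mllib's Summarizer (called from StandardScaler.fit) asking for org.apache.spark.sql.catalyst.trees.BinaryLike. As far as I know, that class only exists in spark-catalyst 3.2 and later. Because of the older direct dependencies, Maven resolves an older spark-sql/spark-catalyst than the one spark-mllib 3.5.0 was built against; mvn dependency:tree shows what actually gets resolved. The usual fix is to give every Spark artifact the same version, for example a single spark.version property set to 3.5.0, which fits the _2.12 suffix and scala-library 2.12.18. The small check below prints which Spark version and which catalyst jar are on the classpath, and whether BinaryLike can be loaded. ClasspathCheck is a hypothetical name.

```scala
import org.apache.spark.SPARK_VERSION

// Diagnostic: reports the Spark version on the classpath and whether a given class can be loaded
object ClasspathCheck {
  def canLoad(className: String): Boolean =
    try { Class.forName(className); true }
    catch { case _: ClassNotFoundException | _: NoClassDefFoundError => false }

  def main(args: Array[String]): Unit = {
    println(s"spark-core version: $SPARK_VERSION")
    // location of the jar that provides the spark-catalyst classes
    val catalystJar = classOf[org.apache.spark.sql.catalyst.InternalRow]
      .getProtectionDomain.getCodeSource.getLocation
    println(s"spark-catalyst loaded from: $catalystJar")
    println(s"BinaryLike present: ${canLoad("org.apache.spark.sql.catalyst.trees.BinaryLike")}")
  }
}
```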
