Spark数据统计指标计算

前言

在机器学习训练模型时,如果遇到空值,一般有三种处理方法,分别是删除法、替换法和插补法。删除法是指当缺失的观测比例非常低时(如5%以内),直接删除存在缺失的观测,或者当某些变量的缺失比例非常高时(如85%以上),直接删除这些缺失的变量;替换法是指用某种常数直接替换那些缺失值,例如,对连续变量而言,可以使用均值或中位数替换,对于离散变量,可以使用众数替换;插补法是指根据其他非缺失的变量或观测来预测缺失值,常见的插补法有回归插补法、K近邻插补法、拉格朗日插补法等。
用Python很容易就能得到均值、众数、中位数等等指标、而Spark要统计这些指标,却没有现成的API,接下来将一一进行介绍,如何计算这些指标

数据准备

1,11
2,12
3,12
3,13

入口代码

val spark = SparkSession
      .builder
      .appName(s"${this.getClass.getSimpleName}")
      .master("local[*]")
      .getOrCreate()

    import spark.implicits._
    import org.apache.spark.sql.functions._
    val df = spark.read.textFile("./data/mode")
      .map(_.split(","))
      .map(x => (x(0), x(1)))
      .toDF("col1", "col2")
      .cache()

describe

DataFrame的describe方法,可以统计出数据特征的总数、平均值、方差、最小值、最大值

df.describe().show()

结果:

+-------+------------------+------------------+
|summary|              col1|              col2|
+-------+------------------+------------------+
|  count|                 4|                 4|
|   mean|              2.25|              12.0|
| stddev|0.9574271077563381|0.8164965809277263|
|    min|                 1|                11|
|    max|                 3|                13|
+-------+------------------+------------------+

那么,如果我们现在想要得到col1列的均值怎么算?
这时候可以借助collect算子来实现:

val col1_mean = df.describe("col1").collect()(1).get(1)
println(col1_mean)

结果:

4

看起来很完美,但是官方并不推荐,他们是这样说的:

This function is meant for exploratory data analysis, as we make no guarantee about the backward compatibility of the schema of the resulting Dataset. If you want to programmatically compute summary statistics, use the agg function instead.
这句话大概意思就是,人家研发的describe方法目的是用于数据探索性分析,如果你想以编程方式计算汇总统计信息,改为使用agg函数,否则到了后面的版本出现不兼容的问题,你别来找我

推荐的写法:

import org.apache.spark.sql.functions._
df.agg(mean($"col1").alias("col1_mean")).show()

结果:

+---------+
|col1_mean|
+---------+
|     2.25|
+---------+

同理,要得到这个值,需要借助collect算子

import org.apache.spark.sql.functions._
val col1_mean = df.agg(mean($"col1").alias("col1_mean")).collect()(0).get(0)
println(col1_mean)

结果:

2.25

其他四种获取方式类似,这里不再赘述

求众数

定义:是一组数据中出现次数最多的数值,叫众数

val v1 = df.select("col1")
          .groupBy("col1")
          .count()
          .orderBy(df("col1").desc)
          .collect()(0).get(0)
println(v1)

结果:

3

当然也可以用sql:

df.createTempView(viewName = "view1")
        import spark.sql
        val v2 = sql(sqlText = "select col1,count(*) as ct1 from view1 group by col1 order by ct1 desc")
          .select("col1").collect()(0).get(0)
println(v2)
3

中位数

定义:对于有限的数集,可以通过把所有观察值高低排序后找出正中间的一个作为中位数。如果观察值有偶数个,通常取最中间的两个数值的平均数作为中位数

这个难度比求众数大了一些,原因:

  • 求众数,并不要在意是否有脏数据(因为如果众数为脏数据,那么这个字段基本上可以删掉了),但是中位数就不一样,很有可能出现脏数据,因此,再求中位数的时候,得剔除掉脏数据(包括:空值、异常值等等)
  • 中位数先得求出总数,这个倒不难,难点在:比如有1万多条数据,就算你知道中位数就是第五千条,你怎么能把它找出来

至于剔除脏数据,这里就不再演示了,直接看计算中位数的过程:已经封装为方法,直接用即可

/**
    * 计算中位数的方法
    *
    * @param spark    SparkSession
    * @param df       DataFrame
    * @param col_name the col which to calculate the middle
    * @return middle
    */
  def calc_middle(spark: SparkSession, df: DataFrame, col_name: String): Double = {
    import org.apache.spark.sql.functions._
    import spark.implicits._
    //统计总行数
    val c = df.count()
    // 增加自增长的列
    val df_add_id = df.withColumn("id", monotonically_increasing_id() + 1)
    if (c % 2 == 0) { // 如果是偶数
      val bf = df_add_id.select(col_name).where($"id" === c / 2)
        .collect()(0).get(0).toString.trim.toDouble
      val af = df_add_id.select(col_name).where($"id" === c / 2 + 1)
        .collect()(0).get(0).toString.trim.toDouble
      (bf + af) / 2
    } else {
      df_add_id.select(col_name)
        .where($"id" === (c + 1) / 2)
        .collect()(0).get(0).toString.trim.toDouble
    }
  }

测试数据(总行数为偶数):

1,11
2,12
3,12
3,13

测试:

 println(calc_middle(spark, df, "col1"))

结果:

2.5

测试数据(总行数为奇数):

1,11
2,12
3,12
3,13
4,14

测试:

 println(calc_middle(spark, df, "col1"))

结果:

3.0

后记

看来虽然spark求中位数、众数没有现成的方法,不过写起来也就几行代码而已

你可能感兴趣的:(Spark)