Spark source code analysis: StatCounter and how to use it

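StatCounter
This is a statistics class that lives in the org.apache.spark.util package.
An RDD[Double] picks up some extra functionality through the implicit conversion to DoubleRDDFunctions, including .stats(), which produces a StatCounter: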

  def stats(): StatCounter = self.withScope {
    self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b))
  }
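The shape of stats() (summarize each partition, then reduce the partial results pairwise) can be sketched on plain Scala collections, without Spark. This is only an illustration: Partial, summarize, merge and the nested Seq standing in for partitions are hypothetical names, and the summary tracks just count, sum, min and max rather than everything StatCounter does.

```scala
object StatsShapeSketch {
  // Hypothetical stand-in for a per-partition summary: (count, sum, min, max).
  type Partial = (Long, Double, Double, Double)

  // Summarize one "partition" in a single pass over its iterator.
  def summarize(nums: Iterator[Double]): Partial =
    nums.foldLeft((0L, 0.0, Double.PositiveInfinity, Double.NegativeInfinity)) {
      case ((n, s, lo, hi), v) => (n + 1, s + v, math.min(lo, v), math.max(hi, v))
    }

  // Combine two partial summaries into one.
  def merge(a: Partial, b: Partial): Partial =
    (a._1 + b._1, a._2 + b._2, math.min(a._3, b._3), math.max(a._4, b._4))

  // Same shape as mapPartitions(...).reduce(...), on plain collections.
  // Like reduce on an RDD, this fails if there are no partitions at all.
  def stats(partitions: Seq[Seq[Double]]): Partial =
    partitions.map(p => summarize(p.iterator)).reduce((a, b) => merge(a, b))
}
```

Because each partition is reduced to a small fixed-size summary first, only those summaries need to be combined at the end, not the raw values.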

I won't go into the rest of DoubleRDDFunctions here; the focus is the StatCounter class itself.

Functionality
The class's doc comment reads:

 * A class for tracking the statistics of a set of numbers (count, mean and variance) in a
 * numerically robust way. Includes support for merging two StatCounters. Based on Welford
 * and Chan's [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance algorithms]]
 * for running variance.

From the description, the class tracks a set of numbers and reports their count, mean and variance, and it supports merging two StatCounters.

Constructor:

class StatCounter(values: TraversableOnce[Double]) extends Serializable {

}

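The parameter type is TraversableOnce, the common parent of Traversable and Iterator. Its doc comment says:

 *  This trait exists primarily to eliminate code duplication between
 *  `Iterator` and `Traversable`

So an iterator can be passed in here, which is exactly what stats() does with the per-partition iterator nums.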

Properties

  private var n: Long = 0     // Running count of our values
  private var mu: Double = 0  // Running mean of our values
  private var m2: Double = 0  // Running variance numerator (sum of (x - mean)^2); dividing it by the count gives the variance
  private var maxValue: Double = Double.NegativeInfinity // Running max of our values
  private var minValue: Double = Double.PositiveInfinity // Running min of our values

  def count: Long = n // returns n, the count

  def mean: Double = mu // returns mu, the mean

  def sum: Double = n * mu // the sum, computed as the mean multiplied by the count

  def max: Double = maxValue // the maximum

  def min: Double = minValue // the minimum

Initialization

 /** Initialize the StatCounter with no values.
    Auxiliary constructor for the case where there are no values.
  */
  def this() = this(Nil)

  /** Add a value into this StatCounter, updating the internal statistics.
    State updates go through merge; this overload takes a single Double.
   */
  def merge(value: Double): StatCounter = {
    val delta = value - mu // deviation of the new value from the current (old) mean
    n += 1 // increment the count
    // Call the old mean mu1 and the old count c, so mu1 = sum / c and the new mean is
    // mu2 = (sum + value) / (c + 1). Then
    // mu2 - mu1 = (c * value - c * mu1) / (c * (c + 1)) = (value - mu1) / (c + 1) = delta / n,
    // because n has already been incremented to c + 1.
    mu += delta / n
    // mu was updated on the previous line, so this adds (value - mu1) * (value - mu2),
    // which is Welford's update for the sum of squared deviations.
    m2 += delta * (value - mu)
    maxValue = math.max(maxValue, value)
    minValue = math.min(minValue, value)
    //maxValue = if (maxValue > value) maxValue else value
    //minValue = if (minValue < value) minValue else value
    this // return this instance
  }
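The m2 update is the least obvious line, so here is a small sketch that checks it numerically. It relies on the identity M2_new = M2_old + (x - mu_old) * (x - mu_new) and compares the running result with a plain two-pass computation. WelfordSketch, running and twoPass are hypothetical names for illustration and are not part of Spark.

```scala
object WelfordSketch {
  // Runs the same per-value update as merge(value) and returns (count, mean, m2).
  def running(values: Seq[Double]): (Long, Double, Double) = {
    var n = 0L
    var mu = 0.0
    var m2 = 0.0
    for (value <- values) {
      val delta = value - mu      // distance from the old mean
      n += 1
      mu += delta / n             // new mean
      m2 += delta * (value - mu)  // old-mean distance times new-mean distance
    }
    (n, mu, m2)
  }

  // Two-pass reference: compute the mean first, then sum the squared deviations.
  // Assumes a non-empty input.
  def twoPass(values: Seq[Double]): (Long, Double, Double) = {
    val mean = values.sum / values.size
    val m2 = values.map(v => (v - mean) * (v - mean)).sum
    (values.size.toLong, mean, m2)
  }
}
```

For the values 1.0, 2.2, 3.1 the running update adds 0, 0.72 and 1.5 to m2 in turn, which gives 2.22, the same sum of squared deviations the two-pass version computes around the mean 2.1.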

For a TraversableOnce argument there is this merge overload:

  /** Add multiple values into this StatCounter, updating the internal statistics.
    Iterates with foreach and updates the state one value at a time.
   */
  def merge(values: TraversableOnce[Double]): StatCounter = {
    values.foreach(v => merge(v))
    this
  }

This overload is called in the class body, so it runs when an instance is constructed:

merge(values)

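Merging with another StatCounter:
  /** Merge another StatCounter into this one, adding up the internal statistics. */
  def merge(other: StatCounter): StatCounter = {
    if (other == this) { // other is the same object as this (== is reference equality unless equals is overridden)
      merge(other.copy())  // Avoid overwriting fields in a weird order
    } else {
      if (n == 0) { // this counter is empty: simply take over the other counter's fields
        mu = other.mu
        m2 = other.m2
        n = other.n
        maxValue = other.maxValue
        minValue = other.minValue
      } else if (other.n != 0) { // the other counter is non-empty too, so both hold data
        val delta = other.mu - mu // difference between the two means
        // All three branches compute the same weighted mean
        // (mu * n + other.mu * other.n) / (n + other.n). When one counter holds more than
        // ten times as many values as the other, the result is written as a small
        // correction to the larger counter's mean.
        if (other.n * 10 < n) {
          // Substituting delta = other.mu - mu and putting everything over (n + other.n)
          // gives back the formula in the final else branch.
          mu = mu + (delta * other.n) / (n + other.n)
        } else if (n * 10 < other.n) {
          mu = other.mu - (delta * n) / (n + other.n)
        } else {
          mu = (mu * n + other.mu * other.n) / (n + other.n) // combined sum divided by combined count
        }
        // Both sums of squared deviations plus a term for the gap between the two means
        m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
        n += other.n
        maxValue = math.max(maxValue, other.maxValue)
        minValue = math.min(minValue, other.minValue)
      }
      this
    }
  }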
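As a sanity check on the merge formulas, the sketch below combines the summaries of two chunks and compares the result with a summary of the whole data set. Summary, summarize and combine are hypothetical names, not Spark's API, and the sketch only uses the plain weighted-mean branch.

```scala
object ChanMergeSketch {
  // Hypothetical immutable summary holding (count, mean, m2).
  final case class Summary(n: Long, mu: Double, m2: Double) {
    def variance: Double = if (n == 0) Double.NaN else m2 / n
  }

  // Build a summary with a plain two-pass computation.
  def summarize(values: Seq[Double]): Summary =
    if (values.isEmpty) Summary(0L, 0.0, 0.0)
    else {
      val mean = values.sum / values.size
      Summary(values.size.toLong, mean, values.map(v => (v - mean) * (v - mean)).sum)
    }

  // Combine two summaries with the same formulas as merge(other).
  def combine(a: Summary, b: Summary): Summary =
    if (a.n == 0) b
    else if (b.n == 0) a
    else {
      val total = a.n + b.n
      val delta = b.mu - a.mu
      val mu = (a.mu * a.n + b.mu * b.n) / total
      // Within-group sums of squares plus the between-group term.
      val m2 = a.m2 + b.m2 + delta * delta * a.n * b.n / total
      Summary(total, mu, m2)
    }
}
```

Splitting 1.0, 2.2, 3.1 into the chunks (1.0, 2.2) and (3.1) gives m2 values of 0.72 and 0, and the between-group term 1.5 * 1.5 * 2 * 1 / 3 = 1.5 brings the merged m2 to 2.22, matching the single-pass result over all three values.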


 /** Clone this StatCounter: returns a new instance with all fields copied over. */
  def copy(): StatCounter = {
    val other = new StatCounter
    other.n = n
    other.mu = mu
    other.m2 = m2
    other.maxValue = maxValue
    other.minValue = minValue
    other
  }

Other methods
  /** Return the variance of the values: the sum of squared deviations divided by the count, or NaN when n is 0. */
  def variance: Double = {
    if (n == 0) {
      Double.NaN
    } else {
      m2 / n
    }
  }

  /**
   * Return the sample variance, which corrects for bias in estimating the variance by dividing
   * by N-1 instead of N.
   */
  def sampleVariance: Double = {
    if (n <= 1) {
      Double.NaN
    } else {
      m2 / (n - 1)
    }
  }

  /** Return the standard deviation of the values. */
  def stdev: Double = math.sqrt(variance)

  /**
   * Return the sample standard deviation of the values, which corrects for bias in estimating the
   * variance by dividing by N-1 instead of N.
   */
  def sampleStdev: Double = math.sqrt(sampleVariance)

  /** Override toString. */
  override def toString: String = {
    "(count: %d, mean: %f, stdev: %f, max: %f, min: %f)".format(count, mean, stdev, max, min)
  }
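To make the difference between variance and sampleVariance concrete, here is a minimal sketch of the two formulas as standalone functions. VarianceSketch and its helpers are hypothetical; they only mirror the accessors above, taking the count n and the sum of squared deviations m2 as arguments.

```scala
object VarianceSketch {
  // Population variance: divide the sum of squared deviations by n.
  def variance(n: Long, m2: Double): Double =
    if (n == 0) Double.NaN else m2 / n

  // Sample variance: divide by n - 1 instead; undefined for fewer than two values.
  def sampleVariance(n: Long, m2: Double): Double =
    if (n <= 1) Double.NaN else m2 / (n - 1)

  def stdev(n: Long, m2: Double): Double = math.sqrt(variance(n, m2))

  def sampleStdev(n: Long, m2: Double): Double = math.sqrt(sampleVariance(n, m2))
}
```

For the three values 1.0, 2.2, 3.1 (n = 3, m2 = 2.22), the population variance is about 0.74 while the sample variance is about 1.11, so the gap is noticeable for small data sets and shrinks as n grows.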

Companion object
object StatCounter {
  /** Build a StatCounter from a list of values. This overload takes a TraversableOnce. */
  def apply(values: TraversableOnce[Double]): StatCounter = new StatCounter(values)

  /** Build a StatCounter from a list of values passed as variable-length arguments. The varargs arrive as a Seq[Double]. */
  def apply(values: Double*): StatCounter = new StatCounter(values)
}

Usage in code:

println(sc.parallelize(Seq(1.0,2.2,3.1)).stats())
Output:
(count: 3, mean: 2.100000, stdev: 0.860233, max: 3.100000, min: 1.000000)

