StatCounter
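StatCounter is a class for computing statistics, found in the org.apache.spark.util package.
For an RDD[Double], the implicit conversion to DoubleRDDFunctions provides extra methods. One of them is stats(), which returns one of these objects: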
def stats(): StatCounter = self.withScope {
self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b))
}
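The following is a minimal local sketch of what stats() relies on, runnable without a Spark cluster. It is a stand-in, not Spark code. `Summary`, `Summary.of` and the chunk size 137 are hypothetical. Each chunk plays the role of one partition. Each chunk summary comes from a plain two-pass scan rather than StatCounter's running update. The merge uses the pairwise combination that StatCounter applies in merge(other: StatCounter), in its simplest weighted-average form for the mean. An empty summary acts as an identity, which matters because a partition can be empty.

```scala
// Hypothetical stand-in for StatCounter, reduced to count, mean and m2
// (sum of squared deviations from the mean). Not Spark's class.
final case class Summary(n: Long, mu: Double, m2: Double) {
  def variance: Double = if (n == 0) Double.NaN else m2 / n

  // Pairwise merge (Chan et al.), using the plain weighted average for the mean.
  def merge(other: Summary): Summary =
    if (n == 0) other
    else if (other.n == 0) this
    else {
      val total = n + other.n
      val delta = other.mu - mu
      val mean = (mu * n + other.mu * other.n) / total
      Summary(total, mean, m2 + other.m2 + delta * delta * n * other.n / total)
    }
}

object Summary {
  // Summarise one chunk (a stand-in for one partition) with a two-pass scan.
  def of(xs: Seq[Double]): Summary =
    if (xs.isEmpty) Summary(0L, 0.0, 0.0)
    else {
      val mean = xs.sum / xs.size
      Summary(xs.size.toLong, mean, xs.map(x => (x - mean) * (x - mean)).sum)
    }
}

// Local analogue of mapPartitions(...).reduce(merge): chunk, summarise, reduce.
val data: Seq[Double] = (1 to 1000).map(i => math.sin(i.toDouble) * 100)
val partitions: List[Seq[Double]] = data.grouped(137).toList
val merged: Summary = partitions.map(Summary.of).reduce((a, b) => a.merge(b))
val whole: Summary = Summary.of(data)
println(s"merged: n=${merged.n} mean=${merged.mu} var=${merged.variance}")
println(s"whole:  n=${whole.n} mean=${whole.mu} var=${whole.variance}")
```

The merge is associative up to rounding. That is why reduce can combine per-partition results in whatever order they arrive, and why the mapPartitions + reduce shape works.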
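I won't cover DoubleRDDFunctions' other methods here; the rest of this post walks through the StatCounter class itself.
Purpose
The class's doc comment reads: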
* A class for tracking the statistics of a set of numbers (count, mean and variance) in a
* numerically robust way. Includes support for merging two StatCounters. Based on Welford
* and Chan's [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance algorithms]]
* for running variance.
So the class tracks a set of numbers and maintains their count, mean and variance in a numerically robust way; it also supports merging two StatCounters.
Constructor:
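"Numerically robust" is the key phrase. The one-pass textbook formula E[x²] - (E[x])² subtracts two huge, nearly equal numbers. When the values sit far from zero, rounding swallows the answer. StatCounter avoids this by updating the mean and m2 incrementally instead of keeping a raw sum of squares. The sketch below compares the naive formula with a two-pass computation on four values with a small spread on top of a 10^9 offset. The helper names are hypothetical, not Spark code.

```scala
// Hypothetical helpers for illustration; not part of Spark's API.
object NaiveVsTwoPass {
  // One-pass "textbook" formula E[x^2] - (E[x])^2: subtracts two huge,
  // nearly equal numbers, so rounding error dominates the result.
  def naiveVariance(xs: Seq[Double]): Double = {
    val n = xs.size.toDouble
    val sum = xs.sum
    val sumSq = xs.map(x => x * x).sum
    sumSq / n - (sum / n) * (sum / n)
  }

  // Two-pass reference: compute the mean first, then average squared deviations.
  def twoPassVariance(xs: Seq[Double]): Double = {
    val mean = xs.sum / xs.size
    xs.map(x => (x - mean) * (x - mean)).sum / xs.size
  }
}

// Spread 4, 7, 13, 16 on top of a large offset; true population variance is 22.5.
val shifted: Seq[Double] = Seq(4.0, 7.0, 13.0, 16.0).map(_ + 1e9)
println(NaiveVsTwoPass.naiveVariance(shifted))   // not 22.5: rounding swamps the true value
println(NaiveVsTwoPass.twoPassVariance(shifted)) // 22.5
```

Here the squares are around 10^18, where adjacent doubles are 128 apart. The naive result is therefore a multiple of 128 and cannot equal 22.5.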
class StatCounter(values: TraversableOnce[Double]) extends Serializable {
}
The parameter type is TraversableOnce, the common parent trait of Traversable and Iterator. Its doc comment says:
* This trait exists primarily to eliminate code duplication between
* `Iterator` and `Traversable`
So an Iterator can be passed in here, as can an ordinary collection such as a List.
Fields
private var n: Long = 0 // Running count of our values
private var mu: Double = 0 // Running mean of our values
private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2); dividing it by the count gives the variance
private var maxValue: Double = Double.NegativeInfinity // Running max of our values
private var minValue: Double = Double.PositiveInfinity // Running min of our values
def count: Long = n // returns n, the count
def mean: Double = mu // returns mu, the mean
def sum: Double = n * mu // the sum, computed as mean times count
def max: Double = maxValue // the maximum
def min: Double = minValue // the minimum
Initialization
/** Initialize the StatCounter with no values.
Auxiliary constructor for the case with no values
*/
def this() = this(Nil)
/** Add a value into this StatCounter, updating the internal statistics.
State is updated mainly through merge; this overload takes a single Double.
*/
def merge(value: Double): StatCounter = {
val delta = value - mu // deviation of the new value from the current (old) mean
n += 1 // one more value
mu += delta / n // Call the old mean mu1 and the new one mu2, and let count be the old count. Then mu1 = sum / count and mu2 = (sum + value) / (count + 1), so mu2 - mu1 = (count * value - count * mu1) / (count * (count + 1)) = (value - mu1) / (count + 1). Since value - mu1 = delta and n is now count + 1, the increment is delta / n
m2 += delta * (value - mu) // note that mu was already updated on the previous line; the derivation is similar to the one above and is not expanded here
maxValue = math.max(maxValue, value)
minValue = math.min(minValue, value)
//maxValue = if(maxValue > value) maxValue else value
//minValue = if(minValue < value) minValue else value
this // return this counter
}
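The m2 derivation skipped above can be filled in as follows. Here n is the count after `n += 1`, x_n is the new value, and μ_{n-1} and μ_n are the means before and after the update. M_{2,n} is m2 after the update and δ is the code's delta. The cross term vanishes because the first n-1 deviations from their own mean sum to zero. The last line is exactly `m2 += delta * (value - mu)` with the already-updated mu.

```latex
\begin{aligned}
\delta &= x_n - \mu_{n-1} \\
\mu_n &= \frac{(n-1)\,\mu_{n-1} + x_n}{n} = \mu_{n-1} + \frac{\delta}{n} \\
x_n - \mu_n &= \delta - \frac{\delta}{n} = \frac{n-1}{n}\,\delta \\
M_{2,n} &= \sum_{i=1}^{n} (x_i - \mu_n)^2
        = \sum_{i=1}^{n-1} \bigl((x_i - \mu_{n-1}) + (\mu_{n-1} - \mu_n)\bigr)^2 + (x_n - \mu_n)^2 \\
        &= M_{2,n-1} + 0 + (n-1)\,\frac{\delta^2}{n^2} + \frac{(n-1)^2}{n^2}\,\delta^2 \\
        &= M_{2,n-1} + \frac{n-1}{n}\,\delta^2
         = M_{2,n-1} + \delta\,(x_n - \mu_n)
\end{aligned}
```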
For a TraversableOnce there is this overload of merge:
/** Add multiple values into this StatCounter, updating the internal statistics.
For multiple values, foreach is called and each value updates the state in turn.
*/
def merge(values: TraversableOnce[Double]): StatCounter = {
values.foreach(v => merge(v))
this
}
This overload is also called in the class body, so it runs when the object is constructed:
merge(values)
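Here is a minimal Scala sketch of the running update. It assumes a hypothetical class `RunningStats` that keeps only n, mu and m2; it is not Spark's StatCounter. `add` follows the same statement order as merge(value: Double). `addAll` mirrors merge(values) by looping over an Iterator, the same kind of argument the TraversableOnce constructor accepts. The demo also feeds in values 4, 7, 13 and 16 above a 10^9 offset to show the running update stays exact there (population variance 22.5).

```scala
// Hypothetical re-creation of StatCounter's running update, for illustration only.
final class RunningStats {
  private var n: Long = 0L     // running count
  private var mu: Double = 0.0 // running mean
  private var m2: Double = 0.0 // running sum of squared deviations from the mean

  // Same order as merge(value: Double): delta uses the old mean,
  // the m2 increment uses the freshly updated mean.
  def add(value: Double): RunningStats = {
    val delta = value - mu
    n += 1
    mu += delta / n
    m2 += delta * (value - mu)
    this
  }

  // Mirrors merge(values): push every value through the single-value update.
  def addAll(values: Iterator[Double]): RunningStats = {
    values.foreach(v => add(v))
    this
  }

  def count: Long = n
  def mean: Double = mu
  def variance: Double = if (n == 0) Double.NaN else m2 / n
}

val stats = new RunningStats().addAll(Iterator(1.0, 2.2, 3.1))
println(s"count=${stats.count} mean=${stats.mean} variance=${stats.variance}")

// Small spread on a large offset: every intermediate value here is exact in Double.
val offset = new RunningStats().addAll(Iterator(4.0, 7.0, 13.0, 16.0).map(_ + 1e9))
println(offset.variance) // 22.5
```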
Merging with another StatCounter:
/** Merge another StatCounter into this one, adding up the internal statistics.
*/
def merge(other: StatCounter): StatCounter = {
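if (other == this) { // other is the same object as this, i.e. a counter is being merged into itself
merge(other.copy()) // Avoid overwriting fields in a weird order
} else {
if (n == 0) { // this counter holds no values yet, so simply copy over the other counter's fields
mu = other.mu
m2 = other.m2
n = other.n
maxValue = other.maxValue
minValue = other.minValue
} else if (other.n != 0) { // the other count is non-zero too, so both counters hold data
val delta = other.mu - mu // difference between the two means
if (other.n * 10 < n) { // Compare the two counts: when one side holds more than 10 times as many values, the new mean is computed as a small correction to the larger side's mean; otherwise the plain weighted average is used
mu = mu + (delta * other.n) / (n + other.n) // to derive this, write the first mu as (mu * n) / n and collect terms; the result is the version in the final else branch
} else if (n * 10 < other.n) {
mu = other.mu - (delta * n) / (n + other.n)
} else {
mu = (mu * n + other.mu * other.n) / (n + other.n) // add the two sums, add the two counts, and divide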
}
m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
n += other.n
maxValue = math.max(maxValue, other.maxValue)
minValue = math.min(minValue, other.minValue)
}
this
}
}
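The formulas behind this method are listed below, with a standing for this counter and b for other, so δ is the code's `delta = other.mu - mu`. The three expressions for the mean are algebraically equal. That is why the branches on the 10x size ratio change only rounding behaviour, not the result. The last line is the amount added to m2.

```latex
\begin{aligned}
n &= n_a + n_b, \qquad \delta = \mu_b - \mu_a \\
\mu &= \frac{n_a\mu_a + n_b\mu_b}{n}
     = \mu_a + \frac{n_b\,\delta}{n}
     = \mu_b - \frac{n_a\,\delta}{n} \\
\mu_a - \mu &= -\frac{n_b\,\delta}{n}, \qquad \mu_b - \mu = \frac{n_a\,\delta}{n} \\
M_2 &= \sum_{x \in a}(x - \mu)^2 + \sum_{x \in b}(x - \mu)^2 \\
    &= M_{2,a} + n_a(\mu_a - \mu)^2 + M_{2,b} + n_b(\mu_b - \mu)^2 \\
    &= M_{2,a} + M_{2,b} + \frac{n_a n_b^2 + n_b n_a^2}{n^2}\,\delta^2
     = M_{2,a} + M_{2,b} + \frac{n_a n_b}{n}\,\delta^2
\end{aligned}
```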
/** Clone this StatCounter
Copies every field into a new StatCounter */
def copy(): StatCounter = {
val other = new StatCounter
other.n = n
other.mu = mu
other.m2 = m2
other.maxValue = maxValue
other.minValue = minValue
other
}
Other methods
/** Return the variance of the values. This is the sum of squared deviations divided by the count; it is NaN when n is 0. */
def variance: Double = {
if (n == 0) {
Double.NaN
} else {
m2 / n
}
}
/**
* Return the sample variance, which corrects for bias in estimating the variance by dividing
* by N-1 instead of N.
*/
def sampleVariance: Double = {
if (n <= 1) {
Double.NaN
} else {
m2 / (n - 1)
}
}
/** Return the standard deviation of the values. */
def stdev: Double = math.sqrt(variance)
/**
* Return the sample standard deviation of the values, which corrects for bias in estimating the
* variance by dividing by N-1 instead of N.
*/
def sampleStdev: Double = math.sqrt(sampleVariance)
/**
Overrides toString
**/
override def toString: String = {
"(count: %d, mean: %f, stdev: %f, max: %f, min: %f)".format(count, mean, stdev, max, min)
}
Companion object
object StatCounter {
/** Build a StatCounter from a list of values. This overload takes a TraversableOnce. */
def apply(values: TraversableOnce[Double]): StatCounter = new StatCounter(values)
/** Build a StatCounter from a list of values passed as variable-length arguments. This overload takes any number of Double arguments. */
def apply(values: Double*): StatCounter = new StatCounter(values)
}
Usage:
println(sc.parallelize(Seq(1.0,2.2,3.1)).stats())
Output:
(count: 3, mean: 2.100000, stdev: 0.860233, max: 3.100000, min: 1.000000)
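The printed numbers can be checked by hand. The sampleVariance and sampleStdev figures are extra; toString does not print them.

```latex
\begin{aligned}
\text{mean} &= \frac{1.0 + 2.2 + 3.1}{3} = \frac{6.3}{3} = 2.1 \\
M_2 &= (-1.1)^2 + 0.1^2 + 1.0^2 = 1.21 + 0.01 + 1.00 = 2.22 \\
\text{variance} &= \frac{2.22}{3} = 0.74, \qquad \text{stdev} = \sqrt{0.74} \approx 0.860233 \\
\text{sampleVariance} &= \frac{2.22}{2} = 1.11, \qquad \text{sampleStdev} = \sqrt{1.11} \approx 1.053565
\end{aligned}
```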