首先纠正一下标题,这个类不是spark的源码中的,而是scala的源码中的,但是在spark源码中经常用到它。稀疏向量,底层基于索引数组和值数组共同实现。该类的核心思想是用两个数组,一个记录原始向量中非零元素的值,另一个记录原始向量中非零元素在原始向量中的位置。一共有三个数据成员,size记录原始向量的长度,indices数组为索引数组,values数组为值数组,索引数组和值数组的长度必须一致。注意:一个普通的SparseVector向量和普通向量没有区别,只有在这个向量调用了该类的toSparse方法把向量本身做了压缩之后值数组才只是存储非零元素。下面是该类的源码,我在关键的地方都做了详细注释。
class SparseVector @Since("1.0.0") (
@Since("1.0.0") override val size: Int,
@Since("1.0.0") val indices: Array[Int],
@Since("1.0.0") val values: Array[Double]) extends Vector {
require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
s" indices match the dimension of the values. You provided ${indices.length} indices and " +
s" ${values.length} values.")
require(indices.length <= size, s"You provided ${indices.length} indices and values, " +
s"which exceeds the specified vector size ${size}.")
override def toString: String =
s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"
//转化为数组,其中包含所有元素,不是只转化向量中的非零元素
@Since("1.0.0")
override def toArray: Array[Double] = {
val data = new Array[Double](size)
var i = 0
val nnz = indices.length
while (i < nnz) {
data(indices(i)) = values(i)
i += 1
}
data
}
//由一个已经存在的稀疏向量实例构造另一个实例
@Since("1.1.0")
override def copy: SparseVector = {
new SparseVector(size, indices.clone(), values.clone())
}
private[spark] override def asBreeze: BV[Double] = new BSV[Double](indices, values, size)
//定义一个名为foreachActive的方法,foreachActive方法没有返回值,foreachActive方法的参数是一个函数f,这个函数有两个参数并且没有返回值。foreachActive方法的作用是在每个值数组中元素上执行函数f所定义的操作
@Since("1.6.0")
override def foreachActive(f: (Int, Double) => Unit): Unit = {
var i = 0
val localValuesSize = values.length
val localIndices = indices
val localValues = values
while (i < localValuesSize) {
f(localIndices(i), localValues(i))
i += 1
}
}
override def equals(other: Any): Boolean = super.equals(other)
override def hashCode(): Int = {
var result: Int = 31 + size
//end赋值为值数组的长度
val end = values.length
var k = 0
var nnz = 0
//遍历值数组中的所有元素来计算hashCode
while (k < end && nnz < Vectors.MAX_HASH_NNZ) {
val v = values(k)
//如果值数组中的元素不为零
if (v != 0.0) {
//用变量i暂存不为零元素在原始向量中的位置
val i = indices(k)
result = 31 * result + i
val bits = java.lang.Double.doubleToLongBits(v)
result = 31 * result + (bits ^ (bits >>> 32)).toInt
nnz += 1
}
k += 1
}
result
}
//返回值数组的长度,这个长度也必然是索引数组的长度
@Since("1.4.0")
override def numActives: Int = values.length
//返回值数组中非零元素的个数
@Since("1.4.0")
override def numNonzeros: Int = {
var nnz = 0
values.foreach { v =>
if (v != 0.0) {
nnz += 1
}
}
nnz
}
//该方法把一个向量压缩
@Since("1.4.0")
override def toSparse: SparseVector = {
val nnz = numNonzeros
//如果值数组中的元素全部都是非零的,说明这个向量本来就不是稀疏向量,不能再压缩,不做任何处理直接返回
if (nnz == numActives) {
this
} else {
//创建一个长度为非零元素个数的Int类型和Double类型数组记录所有非零元素
val ii = new Array[Int](nnz)
val vv = new Array[Double](nnz)
var k = 0
//遍历原始向量,只存储所有非零元素,索引数组的第k个元素记录数值数组的第k个元素在原始向量中的位置
foreachActive { (i, v) =>
if (v != 0.0) {
ii(k) = i
vv(k) = v
k += 1
}
}
//返回压缩后的稀疏向量
new SparseVector(size, ii, vv)
}
}
//返回原始向量中值最大的那个元素在原始向量中的位置。若最大元素是0,那么返回第一个0元素的位置
@Since("1.5.0")
override def argmax: Int = {
if (size == 0) {
-1
} else {
// Find the max active entry.
var maxIdx = indices(0)
var maxValue = values(0)
var maxJ = 0
var j = 1
val na = numActives
//按照选择排序的思想查找值数组中最大的元素存储在maxValue变量中,并且将最大的元素对应的索引数组的下标存在maxJ中,把最大元素对应的原始向量的序号从索引数组中取出存在maxIdx中
while (j < na) {
val v = values(j)
if (v > maxValue) {
maxValue = v
maxIdx = indices(j)
maxJ = j
}
j += 1
}
//如果值数组中最大的元素是非正数,并且值数组的长度小于元素向量的长度(即元素向量中有值为零的元素),那么找到原始向量中第一个零元素的位置即是第一个最大元素的位置,因为最大元素是0
if (maxValue <= 0.0 && na < size) {
if (maxValue == 0.0) {
if (maxJ < maxIdx) {
var k = 0
while (k < maxJ && indices(k) == k) {
k += 1
}
maxIdx = k
}
} else {
//值数组中最大元素为负数
var k = 0
//找到索引数组中第一个值和下标不相等的那个元素,即为原始向量中的第一个值为零的元素在元素向量中的位置
while (k < na && indices(k) == k) {
k += 1
}
maxIdx = k
}
}
maxIdx
}
}
//slice方法相当于根据selectedIndices数组的元素对SparseVector实例进行转化,比如selectedIndices数组为{5,4,8,9,2},那么就是取出原始数组中的第5,4,8,9,2个元素,并按这个顺序生成新的SparseVetor实例
private[spark] def slice(selectedIndices: Array[Int]): SparseVector = {
var currentIdx = 0
//flatMap对于调用它的数组的每个元素执行一定操作,每个元素可能返回多个新元素。比如selectedIndices数组中元素为{1,2,3},第一趟查找索引数组中值为1的的元素的下标存在iIdx中,也就是说找到原始向量中位置编号为1的元素在索引数组中的下标(为了后面第三行代码由这个下标取出对应的数值)
val (sliceInds, sliceVals) = selectedIndices.flatMap { origIdx =>
//调用java源码中的Arrays类的二分查找静态方法,找到索引数组中值为origIdx的元素的下标
val iIdx = java.util.Arrays.binarySearch(this.indices, origIdx)
val i_v = if (iIdx >= 0) {
//取出数值数组中下标和被查找出来的索引数组的下标相同的元素,即取出元素向量中下标满足origIdx条件的那个元素,放到新的迭代器中
Iterator((currentIdx, this.values(iIdx)))
} else {
Iterator()
}
currentIdx += 1
i_v
}.unzip
new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray)
}
@Since("1.6.0")
override def toJson: String = {
val jValue = ("type" -> 0) ~
("size" -> size) ~
("indices" -> indices.toSeq) ~
("values" -> values.toSeq)
compact(render(jValue))
}
@Since("2.0.0")
override def asML: newlinalg.SparseVector = {
new newlinalg.SparseVector(size, indices, values)
}
}