spark word2vec 源码详细解析

spark word2vec 源码详细解析

  • 简单介绍spark word2vec
    • skip-gram 层次softmax版本的源码解析
    • word2vec 的原理 只需要看层次哈弗曼树skip-gram那部分
    • skip-gram negetive sample 的版本源码解析:

简单介绍spark word2vec

Word2Vec creates vector representation of words in a text corpus.
The algorithm first constructs a vocabulary from the corpus and then learns vector representation of words in the vocabulary.
The vector representation can be used as features in natural language processing and machine learning algorithms.
We used skip-gram model in our implementation and hierarchical softmax method to train the model. The variable names in the implementation matches the original C implementation.
For original C implementation, see https://code.google.com/p/word2vec/
For research papers, see
Efficient Estimation of Word Representations in Vector Space paper1
and
Distributed Representations of Words and Phrases and their Compositionality. paper2
word2vec算法创建了关于语料库里面词的代表————词向量。
该算法首先从语料库构建词汇表,然后学习词汇表中单词的向量表示。 向量表示可用作自然语言处理和机器学习算法中的特征。 sparkMLLIB只实现了skip-gram模型,并使用分层softmax方法来训练模型。spark的代码实现参考原始word2vecC语言代码一致。原始C语言实现见:https://code.google.com/p/word2vec/。相关研究论文见:Efficient Estimation of Word Representations in Vector Space和Distributed Representations of Words and Phrases and their Compositionality。

skip-gram 层次softmax版本的源码解析

package org.apache.spark.mllib.feature
import java.lang.{Iterable => JavaIterable}
import scala.collection.JavaConverters._
import scala.collection.mutable
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd._
import org.apache.spark.sql.SparkSession
import org.apache.spark.util.BoundedPriorityQueue
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
/**
*  Entry in vocabulary   定义词典的属性类   复习:scala的class类别
*/
private case class VocabWord(
  var word: String,   //词
  var cn: Int,        //频次
  var point: Array[Int],   // ARRAY 存的是这个词[叶子结点]的从根节点到叶子节点的路径经过的节点
  var code: Array[Int],   //记录Huffman编码
  var codeLen: Int        //code长度,路径长度 ,存储到达该叶子结点,要经过多少个结点
)
本文只实现skip-gram hierarchical softmax 部分,参照C语言实现的代码:https://code.google.com/p/word2vec/
参照两篇论文:Efficient Estimation of Word Representations in Vector Space & Distributed Representations of Words and Phrases and their Compositionality
@Since("1.1.0")
class Word2Vec extends Serializable with Logging {
//默认参数
  private var vectorSize = 100  //训练vector的长度
  private var learningRate = 0.025  //训练时的学习率
  private var numPartitions = 1   //分区数
  private var numIterations = 1   //迭代次数
  private var seed = Utils.random.nextLong()  //随机种子
  private var minCount = 5   //词的最小出现频次
  private var maxSentenceLength = 1000  //句子的长度

//如果大于maxSentenceLength 句子的长度,将会截断为多个块。
  /**
   * Sets the maximum length (in words) of each sentence in the input data.
   * Any sentence longer than this threshold will be divided into chunks of
   * up to `maxSentenceLength` size (default: 1000)
   */
  @Since("2.0.0")
  def setMaxSentenceLength(maxSentenceLength: Int): this.type = {
    require(maxSentenceLength > 0,
      s"Maximum length of sentences must be positive but got ${maxSentenceLength}")
    this.maxSentenceLength = maxSentenceLength
    this
  }
  /**
   * Sets vector size (default: 100).
   */
  @Since("1.1.0")
  def setVectorSize(vectorSize: Int): this.type = {
    require(vectorSize > 0,
      s"vector size must be positive but got ${vectorSize}")
    this.vectorSize = vectorSize
    this
  }


  /**
   * Sets initial learning rate (default: 0.025).
   */
  @Since("1.1.0")
  def setLearningRate(learningRate: Double): this.type = {
    require(learningRate > 0,
      s"Initial learning rate must be positive but got ${learningRate}")
    this.learningRate = learningRate
    this
  }


  /**
   * Sets number of partitions (default: 1). Use a small number for accuracy. 
   * //设置少数分区有利于准确性
   */
  @Since("1.1.0")
  def setNumPartitions(numPartitions: Int): this.type = {
    require(numPartitions > 0,
      s"Number of partitions must be positive but got ${numPartitions}")
    this.numPartitions = numPartitions
    this
  }


  /**
   * Sets number of iterations (default: 1), which should be smaller than or equal to number of 
   * partitions. 
   *  //设置迭代次数,要小于或者等于分区数
   */
  @Since("1.1.0")
  def setNumIterations(numIterations: Int): this.type = {
    require(numIterations >= 0,
      s"Number of iterations must be nonnegative but got ${numIterations}")
    this.numIterations = numIterations
    this
  }


  /**
   * Sets random seed (default: a random long integer).
   */
  @Since("1.1.0")
  def setSeed(seed: Long): this.type = {
    this.seed = seed
    this
  }


  /**
   * Sets the window of words (default: 5) 
   * //根据单个文本的长度合理设置,目前针对于标题40个字,设置为5
   */
  @Since("1.6.0")
  def setWindowSize(window: Int): this.type = {
    require(window > 0,
      s"Window of words must be positive but got ${window}")
    this.window = window
    this
  }


  /**
   * Sets minCount, the minimum number of times a token must appear to be included in the word2vec
   * model's vocabulary (default: 5).
   * //根据文本的词的频次分布设置,保证覆盖大多数的文本。
   */
  @Since("1.3.0")
  def setMinCount(minCount: Int): this.type = {
    require(minCount >= 0,
      s"Minimum number of times must be nonnegative but got ${minCount}")
    this.minCount = minCount
    this
  }


  private val EXP_TABLE_SIZE = 1000
  private val MAX_EXP = 6
  private val MAX_CODE_LENGTH = 40


  /** context words from [-window, window] */  
  //滑动窗口以中心词的左右各+-window选词。
  private var window = 5


  private var trainWordsCount = 0L
  private var vocabSize = 0
*********transient 解释:
我们都知道一个对象只要实现了Serilizable接口,这个对象就可以被序列化,java的这种序列化模式为开发者提供了很多便利,我们可以不必关系具体序列化的过程,只要这个类实现了Serilizable接口,这个类的所有属性和方法都会自动序列化。
然而在实际开发过程中,我们常常会遇到这样的问题,这个类的有些属性需要序列化,而其他属性不需要被序列化,打个比方,如果一个用户有一些敏感信息(如密码,银行卡号等),为了安全起见,不希望在网络操作(主要涉及到序列化操作,本地序列化缓存也适用)中被传输,这些信息对应的变量就可以加上transient关键字。换句话说,这个字段的生命周期仅存于调用者的内存中而不会写到磁盘里持久化。
总之,java的transient关键字为我们提供了便利,你只需要实现Serilizable接口,将不需要序列化的属性前添加关键字transient,序列化对象的时候,这个属性就不会序列化到指定的目的地中。
*********transient 解释:
  @transient private var vocab: Array[VocabWord] = null
  @transient private var vocabHash = mutable.HashMap.empty[String, Int]

********************************************************************************************************************
from :org.apache.spark.ml.feature.Word2Vec#fit
override def fit(dataset: Dataset[_]): Word2VecModel = {
  transformSchema(dataset.schema, logging = true)
  val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
  val wordVectors = new feature.Word2Vec()
    .setLearningRate($(stepSize))
    .setMinCount($(minCount))
    .setNumIterations($(maxIter))
    .setNumPartitions($(numPartitions))
    .setSeed($(seed))
    .setVectorSize($(vectorSize))
    .setWindowSize($(windowSize))
    .setMaxSentenceLength($(maxSentenceLength))
    .fit(input)
  copyValues(new Word2VecModel(uid, wordVectors).setParent(this))
}
*********************************************************************************************************************
//dataset来自上面的input,里面是:Seq[String]
  private def learnVocab[S <: Iterable[String]](dataset: RDD[S]): Unit = {  //构建每个词的类
    val words = dataset.flatMap(x => x)  //把所有的词压平,统计词频
    vocab = words.map(w => (w, 1))
      .reduceByKey(_ + _)
      .filter(_._2 >= minCount)  //过滤词频大于minCount的词
      .map(x => VocabWord(
        x._1,
        x._2,
        new Array[Int](MAX_CODE_LENGTH),
        new Array[Int](MAX_CODE_LENGTH),
        0))
      .collect()
      .sortWith((a, b) => a.cn > b.cn)  //按频数从大到小排序

    vocabSize = vocab.length
    require(vocabSize > 0, "The vocabulary size should be > 0. You may need to check " +
      "the setting of minCount, which could be large enough to remove all your words in sentences.")


    var a = 0
    while (a < vocabSize) {
      vocabHash += vocab(a).word -> a   //@transient private var vocabHash = mutable.HashMap.empty[String, Int],【词,词频】  生成hashMap(K:word,V:a)--> 对词典中所有元素进行映射,方便查找
      trainWordsCount += vocab(a).cn    //训练词的个数统计
      a += 1
    }
    logInfo(s"vocabSize = $vocabSize, trainWordsCount = $trainWordsCount")
  }

//创建sigmoid函数查询表
  private def createExpTable(): Array[Float] = {
    val expTable = new Array[Float](EXP_TABLE_SIZE)
    var i = 0
    while (i < EXP_TABLE_SIZE) {
      val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
      expTable(i) = (tmp / (tmp + 1.0)).toFloat
      i += 1
    }
    expTable
  }

//构造哈夫曼树
  private def createBinaryTree(): Unit = {
    val count = new Array[Long](vocabSize * 2 + 1)  //二叉树中所有的结点
    val binary = new Array[Int](vocabSize * 2 + 1)  //设置每个结点的Huffman编码:左1,右0
    val parentNode = new Array[Int](vocabSize * 2 + 1)  //存储每个结点的父节点
    val code = new Array[Int](MAX_CODE_LENGTH)  //存储每个叶子结点的Huffman编码
    val point = new Array[Int](MAX_CODE_LENGTH)  //存储每个叶子结点的路径(经历过哪些结点)
    var a = 0
    while (a < vocabSize) { //节点 0~vocabSize-1  赋值为该节点词的频次  左边都是叶子结点
      count(a) = vocab(a).cn
      a += 1
    }
    while (a < 2 * vocabSize) {  //节点 vocabSize~2*vocabSize-1  赋值为1e9  右边都是父节点
      count(a) = 1e9.toInt
      a += 1
    }
    var pos1 = vocabSize - 1
    var pos2 = vocabSize

//min1i和min2i是左右节点
    var min1i = 0
    var min2i = 0


    a = 0
    while (a < vocabSize - 1) {
      if (pos1 >= 0) {
        if (count(pos1) < count(pos2)) {
          min1i = pos1
          pos1 -= 1
        } else {
          min1i = pos2
          pos2 += 1
        }
      } else {
        min1i = pos2
        pos2 += 1
      }
      if (pos1 >= 0) {
        if (count(pos1) < count(pos2)) {
          min2i = pos1
          pos1 -= 1
        } else {
          min2i = pos2
          pos2 += 1
        }
      } else {
        min2i = pos2
        pos2 += 1
      }
      count(vocabSize + a) = count(min1i) + count(min2i)   //从三个点里面找到和最小的两个点
      parentNode(min1i) = vocabSize + a    //父节点
      parentNode(min2i) = vocabSize + a    //父节点
      binary(min2i) = 1          //定义右子树为1
      a += 1
    }
    // Now assign binary code to each vocabulary word
    var i = 0
    a = 0
    while (a < vocabSize) {
      var b = a
      i = 0
      while (b != vocabSize * 2 - 2) {  //哈弗曼树一共有2n-1个节点,所以vocabSize*2-2指的是根节点,遍历a二叉树路径上的每个节点,除了根节点
        code(i) = binary(b)         //第b个结点的Huffman编码是0 or 1
        point(i) = b                //存储路径,经过b结点
        i += 1
        b = parentNode(b)          //按照路径去找下一个节点,遍历b的下个节点
      }
      vocab(a).codeLen = i         //存储到达叶子结点a,要经过多少个结点
      vocab(a).point(0) = vocabSize - 2 //每个词的point(0)都是一样的为vocabSize-2,这个是根节点,在这里哈弗曼树已经建立完成了,point记录的是叶子结点a的从根节点以来的路径,因为哈弗曼树所有词的节点是叶子结点,从根节点到叶子节点上的路径都是中间节点如图一所示的,路径里面的节点都减了vocabSize,因为中间节点是vocabSize-1个,所以又都放在0到vocabSize-1的范围了。
      b = 0
      while (b < i) {        //遍历a二叉树路径上的每个节点
        vocab(a).code(i - b - 1) = code(b)   //根据上一步的结果,对节点a的哈夫曼编码赋值
        vocab(a).point(i - b) = point(b) - vocabSize  //根据上一步的结果,对节点a的路径节点进行赋值
        b += 1
      }
      a += 1    //下一个词
    }
  }


  /**
   * Computes the vector representation of each word in vocabulary.
   * @param dataset an RDD of sentences,
   *                each sentence is expressed as an iterable collection of words
   * @return a Word2VecModel
   */
  @Since("1.1.0")
  def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {


    learnVocab(dataset)         //构建词汇类


    createBinaryTree()          //构建 Huffman 树


    val sc = dataset.context


    val expTable = sc.broadcast(createExpTable())   //广播sigmod查询表
    val bcVocab = sc.broadcast(vocab)               //广播词汇类
    val bcVocabHash = sc.broadcast(vocabHash)       //广播词 词索引
    try {
      doFit(dataset, sc, expTable, bcVocab, bcVocabHash)  
    } finally {
      expTable.destroy(blocking = false)   //销毁广播变量
      bcVocab.destroy(blocking = false)
      bcVocabHash.destroy(blocking = false)
    }
  }


  private def doFit[S <: Iterable[String]](
    dataset: RDD[S], sc: SparkContext,
    expTable: Broadcast[Array[Float]],
    bcVocab: Broadcast[Array[VocabWord]],
    bcVocabHash: Broadcast[mutable.HashMap[String, Int]]) = {
    // each partition is a collection of sentences,
    // will be translated into arrays of Index integer
    val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>  //RDD[S] S为RDD里面最细粒度的数据结构,里面装的就是这个类型S的数据
      // Each sentence will map to 0 or more Array[Int]
      sentenceIter.flatMap { sentence =>
        // Sentence of words, some of which map to a word index
        val wordIndexes = sentence.flatMap(bcVocabHash.value.get) // flatMap对句子中每个词得到index,得到每个句子每个词的index
        // break wordIndexes into trunks of maxSentenceLength when has more
        wordIndexes.grouped(maxSentenceLength).map(_.toArray)  //如果有的句子的长度大于1000,就给它分组为1000单位,并是array | wordIndexes是个Iterable[Int]格式利用grouped函数对其分组。
           wordIndexes.grouped(maxSentenceLength)返回的是:Iterator[Array[Int]]
      }
    }


   //val newSentences = sentences.repartition(numPartitions).cache()   //按照给定的分区数,进行重分区  并且全部cache

//可以改为:
//todo 更改存储方式
    val newSentences = sentences.repartition(numPartitions).cache()
    //todo 对sentence进行checkpoint
    newSentences.sparkContext.setCheckpointDir("hdfs://ns4/user/dd_edw/tmp.db/item_relationship/item_embedding/graph_embedding_rdwalk")
    newSentences.checkpoint()
    newSentences.count
    bcVocabHash.destroy(blocking = false) //TODO 用完了 需要进行释放 销毁

//    val newSentences = sentences.repartition(numPartitions).persist(StorageLevel.MEMORY_AND_DISK_SER)
//    val newSentences = sentences.repartition(numPartitions).persist(StorageLevel.DISK_ONLY)

    val initRandom = new XORShiftRandom(seed)                         //


    if (vocabSize.toLong * vectorSize >= Int.MaxValue) {   //如果词汇量*词向量长度 大于或等于 INT最大值 就抛出异常
      throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" +
        " to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " +
        "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue`.")
    }


    val syn0Global =
      Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)   //初始化叶子节点,分词向量随机设置初始值
    val syn1Global = new Array[Float](vocabSize * vectorSize)                                   //初始化非叶子结点,参数向量设置初始值为0
    val totalWordsCounts = numIterations * trainWordsCount + 1                                  //迭代次数*所有分词的个数 +1 
    var alpha = learningRate                                                                    //学习率


    for (k <- 1 to numIterations) {   //开始迭代
      val bcSyn0Global = sc.broadcast(syn0Global)     
      val bcSyn1Global = sc.broadcast(syn1Global)
      val numWordsProcessedInPreviousIterations = (k - 1) * trainWordsCount //已经迭代过的词数


      val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
        val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
        val syn0Modify = new Array[Int](vocabSize)
        val syn1Modify = new Array[Int](vocabSize)
        /**
        def foldLeft[B](z: B)(op: (B, A) => B): B = {
          var result = z
          this foreach (x => result = op(result, x))
          result
        }
        */
        val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0L, 0L)) { //{}里面是OP[具体操作],初始值是(bcSyn0Global.value, bcSyn1Global.value, 0L, 0L),然后在每个分区里面串行运行,x是case ((syn0, syn1, lastWordCount, wordCount), sentence),最终结果是:(syn0, syn1, lwc, wc) 和Z同种类型。最后的结果(syn0, syn1, lwc, wc)总是更新存在的。总是赋值给B类型。最后结果也是B,B就是(syn0, syn1, lwc, wc)类型的数据。iter每迭代一次sentence就会更新一次B
          case ((syn0, syn1, lastWordCount, wordCount), sentence) =>  //每个分区里面的每个sentence
            var lwc = lastWordCount  //每次迭代的最新的
            var wc = wordCount
            if (wordCount - lastWordCount > 10000) { //当句子迭代10000个词的时候。每迭代10000词的时候就更新一下alpha
              lwc = wordCount   //更改上次词数
              alpha = learningRate *
                (1 - (numPartitions * wordCount.toDouble + numWordsProcessedInPreviousIterations) /
                  totalWordsCounts)   //随着wordCount变大,alpha变小
              if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001  //当小于learningRate * 0.0001时候,直接等于learningRate * 0.0001
              logInfo(s"wordCount = ${wordCount + numWordsProcessedInPreviousIterations}, " +
                s"alpha = $alpha")
            }
            wc += sentence.length //wc就是上次的wordCount,一直累加句子的长度。
            var pos = 0
            while (pos < sentence.length) {  //开始迭代,一个sentence中的pos位置的词,pos从0开始
              val word = sentence(pos)
              val b = random.nextInt(window)  //b是window内的随机数
              // Train Skip-gram
              var a = b
              while (a < window * 2 + 1 - b) {  //因为开始a = b ,从b开始到 window * 2 + 1 - b,也就是取pos词左右window - b 个词,迭代pos附近的窗口:window - b
                if (a != window) { //当a不是中心词
                  val c = pos - window + a   //pos位置的词pos-(window - a)[真实位置]
                  if (c >= 0 && c < sentence.length) {  //pos的左右位置迭代取值可能是负的或者超出句子长度,限定范围
                    val lastWord = sentence(c)    //该词的index
                    val l1 = lastWord * vectorSize  //syn0的index
                    val neu1e = new Array[Float](vectorSize) //相当于公式里面的e,就是x的梯度迭代项
                    // Hierarchical softmax
                    var d = 0
                    while (d < bcVocab.value(word).codeLen) {  //迭代中心词的路径哈夫曼二分类
                      val inner = bcVocab.value(word).point(d)  //路径上节点index
                      val l2 = inner * vectorSize               //syn1对应的index
                      // Propagate hidden -> output    blas.sdot函数解释:sdot(int n, float[] sx, int _sx_offset, int incx, float[] sy, int _sy_offset, int incy),结果是:sx .* sy,并且sx[_sx_offset,incx*n + _sx_offset],sy[_sy_offset,incy*n + _sy_offset]
                      var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)   //向量点乘,syn0 .* syn1 , syn0[l1,l1+1*vectorSize],syn1[l2,l2+1*vectorSize]
                      if (f > -MAX_EXP && f < MAX_EXP) {                    
                        val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                        f = expTable.value(ind)   //索引到sigmod函数表的值
                        val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat   //梯度
                        blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)   //neu1e = g * syn1 + neu1e  blas.saxpy函数解释:saxpy(int n, float sa, float[] sx, int _sx_offset, int incx, float[] sy, int _sy_offset, int incy),结果是:sy= sa*sx+sy,并且sx[_sx_offset,_sx_offset+incx*n],sy[_sy_offset,_sy_offset+incy*n]
                        blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)   //syn1 = g * syn0 + syn1
                        syn1Modify(inner) += 1          //记录参数向量里面的点被更新次数
                      }
                      d += 1
                    }
                    blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)  //syn0 = 1.0f * neu1e + syn0   a的情况下,迭代完成中心词pos附近的一个词的参数向量和词向量
                    syn0Modify(lastWord) += 1
                  }
                }
                a += 1
              }
              pos += 1   //循环到这个句子的下一个中心词
            }
            (syn0, syn1, lwc, wc)
        }
        val syn0Local = model._1   //syn0 为叶子结点向量,即分词向量
        val syn1Local = model._2   //syn1 为非叶子结点向量,即参数向量
        // Only output modified vectors.   Iterator.tabulate函数: Creates an iterator producing the values of a given function over a range of integer values starting from 0.
        Iterator.tabulate(vocabSize) { index =>
          if (syn0Modify(index) > 0) {
            Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
          } else {
            None
          }
        }.flatten ++ Iterator.tabulate(vocabSize) { index =>
          if (syn1Modify(index) > 0) {
            Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
          } else {
            None
          }
        }.flatten    //得到n个词向量的结果,n-1个中间节点的向量结果,两个结果(index,array)拼接起来,并且中间参数节点向量的index 从vocabSize开始编号
      }
      val synAgg = partial.reduceByKey { case (v1, v2) =>   //注意partial是所有分区内部的结果,按照同样的index下的array进行聚合,直接把所有分区的结果暴力累加
          blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
          v1
      }.collect()
      var i = 0
      while (i < synAgg.length) {  //分别得到分词向量和参数向量
        val index = synAgg(i)._1
        if (index < vocabSize) {
          Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
        } else {
          Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
        }
        i += 1
      }
      bcSyn0Global.destroy(false)
      bcSyn1Global.destroy(false)
    }
    newSentences.unpersist()


    val wordArray = vocab.map(_.word)
    new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)  //得到分词向量
  }


  /**
   * Computes the vector representation of each word in vocabulary (Java version).
   * @param dataset a JavaRDD of words
   * @return a Word2VecModel
   */
  @Since("1.1.0")
  def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VecModel = {
    fit(dataset.rdd.map(_.asScala))
  }
}


/**
* Word2Vec model
* @param wordIndex maps each word to an index, which can retrieve the corresponding
*                  vector from wordVectors
* @param wordVectors array of length numWords * vectorSize, vector corresponding
*                    to the word mapped with index i can be retrieved by the slice
*                    (i * vectorSize, i * vectorSize + vectorSize)
*/
@Since("1.1.0")
class Word2VecModel private[spark] (
    private[spark] val wordIndex: Map[String, Int],
    private[spark] val wordVectors: Array[Float]) extends Serializable with Saveable {


  private val numWords = wordIndex.size
  // vectorSize: Dimension of each word's vector.
  private val vectorSize = wordVectors.length / numWords


  // wordList: Ordered list of words obtained from wordIndex.
  private val wordList: Array[String] = {
    val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip
    wl.toArray
  }
  // wordVecNorms: Array of length numWords, each value being the Euclidean norm
  //               of the wordVector.   长度为numWords的数组,每个值都是wordVector的欧几里得范数。
  private val wordVecNorms: Array[Float] = {
    val wordVecNorms = new Array[Float](numWords)
    var i = 0
    while (i < numWords) {
      val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize)
      wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)  
      i += 1
    }
    wordVecNorms
  }


  @Since("1.5.0")
  def this(model: Map[String, Array[Float]]) = {
    this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model))
  }


  override protected def formatVersion = "1.0"


  @Since("1.4.0")
  def save(sc: SparkContext, path: String): Unit = {
    Word2VecModel.SaveLoadV1_0.save(sc, path, getVectors)
  }


  /**
   * Transforms a word to its vector representation
   * @param word a word
   * @return vector representation of word
   */
  @Since("1.1.0")
  def transform(word: String): Vector = {
    wordIndex.get(word) match {
      case Some(ind) =>
        val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize)
        Vectors.dense(vec.map(_.toDouble))
      case None =>
        throw new IllegalStateException(s"$word not in vocabulary")
    }
  }


  /**
   * Find synonyms of a word; do not include the word itself in results.
   * @param word a word
   * @param num number of synonyms to find
   * @return array of (word, cosineSimilarity)
   */
  @Since("1.1.0")
  def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
    val vector = transform(word)
    findSynonyms(vector, num, Some(word))
  }


  /**
   * Find synonyms of the vector representation of a word, possibly
   * including any words in the model vocabulary whose vector respresentation
   * is the supplied vector.
   * @param vector vector representation of a word
   * @param num number of synonyms to find
   * @return array of (word, cosineSimilarity)
   */
  @Since("1.1.0")
  def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
    findSynonyms(vector, num, None)
  }


  /**
   * Find synonyms of the vector representation of a word, rejecting
   * words identical to the value of wordOpt, if one is supplied.
   * @param vector vector representation of a word
   * @param num number of synonyms to find
   * @param wordOpt optionally, a word to reject from the results list
   * @return array of (word, cosineSimilarity)
   */
  private def findSynonyms(
      vector: Vector,   //需要找的这个词向量在所有词向量里面的相似结果
      num: Int,         //需要找的TOPN
      wordOpt: Option[String]): Array[(String, Double)] = {  //返回形式是 (词,相似度)
    require(num > 0, "Number of similar words should > 0")


    val fVector = vector.toArray.map(_.toFloat)  //由double类型变为Float类型,可以节省存储空间
    val cosineVec = new Array[Float](numWords)   //vector 与每个词向量直接的cosine相似度值
    val alpha: Float = 1
    val beta: Float = 0
    // Normalize input vector before blas.sgemv to avoid Inf value  这样对归一化后的结果避免了无穷大的异常出现
    val vecNorm = blas.snrm2(vectorSize, fVector, 1)  //blas.snrm2函数:SNRM2 := sqrt( x'*x ).通过函数名称返回向量的欧几里得范数
    if (vecNorm != 0.0f) {
      blas.sscal(vectorSize, 1 / vecNorm, fVector, 0, 1)  //blas.sscal函数:scales a vector by a constant, uses unrolled loops for increment equal to 1,sscal(int n, float sa, float[] sx, int _sx_offset, int incx) 结果是向量所元素sx*sa,其中 sx[_sx_offset,_sx_offset+incx*n]
    }
    blas.sgemv(
      "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1) //wordVectors * fVector   //wordVectors 500万*100维  fVector 100*1维  cosineVec 500万*1维
//sgemv(java.lang.String trans, int m, int n, float alpha, float[] a, int lda, float[] x, int incx, float beta, float[] y, int incy)  y := alpha*A*x + beta*y,   or   y := alpha*A'*x + beta*y
    var i = 0
    while (i < numWords) {
      val norm = wordVecNorms(i)   //每个词向量的欧几里得范数
      if (norm == 0.0f) {
        cosineVec(i) = 0.0f
      } else {
        cosineVec(i) /= norm  //之前fVector已经除了vecNorm,后面只需各个词向量除以自己的范式就行了
      }
      i += 1
    }

//堆排序取数
    val pq = new BoundedPriorityQueue[(String, Float)](num + 1)(Ordering.by(_._2))


    var j = 0
    while (j < numWords) {
      pq += Tuple2(wordList(j), cosineVec(j))
      j += 1
    }


    val scored = pq.toSeq.sortBy(-_._2)


    val filtered = wordOpt match {
      case Some(w) => scored.filter(tup => w != tup._1)
      case None => scored
    }


    filtered
      .take(num)
      .map { case (word, score) => (word, score.toDouble) }
      .toArray
  }


  /**
   * Returns a map of words to their vector representations.
   */
  @Since("1.2.0")
  def getVectors: Map[String, Array[Float]] = {
    wordIndex.map { case (word, ind) =>
      (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize))
    }
  }


}


@Since("1.4.0")
object Word2VecModel extends Loader[Word2VecModel] {


  private def buildWordIndex(model: Map[String, Array[Float]]): Map[String, Int] = {
    model.keys.zipWithIndex.toMap
  }


  private def buildWordVectors(model: Map[String, Array[Float]]): Array[Float] = {
    require(model.nonEmpty, "Word2VecMap should be non-empty")
    val (vectorSize, numWords) = (model.head._2.length, model.size)
    val wordList = model.keys.toArray
    val wordVectors = new Array[Float](vectorSize * numWords)
    var i = 0
    while (i < numWords) {
      Array.copy(model(wordList(i)), 0, wordVectors, i * vectorSize, vectorSize)
      i += 1
    }
    wordVectors
  }


  private object SaveLoadV1_0 {


    val formatVersionV1_0 = "1.0"


    val classNameV1_0 = "org.apache.spark.mllib.feature.Word2VecModel"


    case class Data(word: String, vector: Array[Float])


    def load(sc: SparkContext, path: String): Word2VecModel = {
      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
      val dataFrame = spark.read.parquet(Loader.dataPath(path))
      // Check schema explicitly since erasure makes it hard to use match-case for checking.
      Loader.checkSchema[Data](dataFrame.schema)


      val dataArray = dataFrame.select("word", "vector").collect()
      val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMap
      new Word2VecModel(word2VecMap)
    }


    def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = {
      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()


      val vectorSize = model.values.head.length
      val numWords = model.size
      val metadata = compact(render(
        ("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~
        ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))


      // We want to partition the model in partitions smaller than
      // spark.kryoserializer.buffer.max
      val bufferSize = Utils.byteStringAsBytes(
        spark.conf.get("spark.kryoserializer.buffer.max", "64m"))
      // We calculate the approximate size of the model
      // We only calculate the array size, considering an
      // average string size of 15 bytes, the formula is:
      // (floatSize * vectorSize + 15) * numWords
      val approxSize = (4L * vectorSize + 15) * numWords
      val nPartitions = ((approxSize / bufferSize) + 1).toInt
      val dataArray = model.toSeq.map { case (w, v) => Data(w, v) }
      spark.createDataFrame(dataArray).repartition(nPartitions).write.parquet(Loader.dataPath(path))
    }
  }


  @Since("1.4.0")
  override def load(sc: SparkContext, path: String): Word2VecModel = {


    val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
    implicit val formats = DefaultFormats
    val expectedVectorSize = (metadata \ "vectorSize").extract[Int]
    val expectedNumWords = (metadata \ "numWords").extract[Int]
    val classNameV1_0 = SaveLoadV1_0.classNameV1_0
    (loadedClassName, loadedVersion) match {
      case (classNameV1_0, "1.0") =>
        val model = SaveLoadV1_0.load(sc, path)
        val vectorSize = model.getVectors.values.head.length
        val numWords = model.getVectors.size
        require(expectedVectorSize == vectorSize,
          s"Word2VecModel requires each word to be mapped to a vector of size " +
          s"$expectedVectorSize, got vector of size $vectorSize")
        require(expectedNumWords == numWords,
          s"Word2VecModel requires $expectedNumWords words, but got $numWords")
        model
      case _ => throw new Exception(
        s"Word2VecModel.load did not recognize model with (className, format version):" +
        s"($loadedClassName, $loadedVersion).  Supported:\n" +
        s"  ($classNameV1_0, 1.0)")
    }
  }
}

word2vec 的原理 只需要看层次哈弗曼树skip-gram那部分

原理部分推荐链接:https://www.cnblogs.com/shixiangwan/p/7808249.html
其中Sparkword2vec使用过程中有以下问题:

  1. 当迭代次数或者分区过多的情况下,会产生Infinity的问题
  2. 训练过程中分区过多准确度会下降
  3. 内存消耗过大,全部cache了
  4. 哈夫曼树的方法时间消耗大。等问题,这些问题最近几天完善都一一解决了嘿嘿

问题1:
针对第一个问题解决方法:
基于以上的源码可以看见:
spark实现skip-gram直接复现原始word2vec-C语言版本。
spark实现得到词向量是累加所有分区和所有迭代的结果,随着迭代次数的增大和分区数增加导致词向量数值异常。采用归一化词向量迭代结果,把每次迭代和每个分区的结果累加并且归一化就可以了。
具体代码如下:
把这个代码

val synAgg = partial.reduceByKey { case (v1, v2) =>
        blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
        v1

改为:

//修改思路就是把所有的向量结果取均值,计算每个向量的在所有分区出现的计数,然后再取均值
      // https://github.com/apache/spark/pull/26722
      // SPARK-24666: do normalization for aggregating weights from partitions.
      // Original Word2Vec either single-thread or multi-thread which do Hogwild-style aggregation.
      // Our approach needs to do extra normalization, otherwise adding weights continuously may
      // cause overflow on float and lead to infinity/-infinity weights.
       val synAgg = partial.mapPartitions { iter =>
        iter.map { case (id, vec) =>
          (id, (vec, 1))
        }
      }.reduceByKey { case ((v1, count1), (v2, count2)) =>
        blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
        (v1, count1 + count2)
      }.map { case (id, (vec, count)) =>
        blas.sscal(vectorSize, 1.0f / count, vec, 1)
        (id, vec)

就可以了。完美解决向量值大的问题。
问题2:
分区即是分治的思想,把数据mapPartitions一下每个分区维护自己的一套参数,后期处理把每个分区的参数累加处理,所以在数据迭代上只是并行迭代累加并未串行按照样本依次迭代。分区过多导致每个分区的数据量过小会减少准确度,但是word2vec的效果跟分词质量和数据量的大小有这很大关系。
可以把里面的bcVocabHash.destroy(blocking = false) // 用完了 需要进行释放 销毁,提前释放

skip-gram negetive sample 的版本源码解析:

参考该代码:github SKNS的实现

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature.sgns

// https://github.com/shubhamchopra/spark/tree/Word2VecSGNS/mllib/src/main/scala/org/apache/spark/ml/feature

import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.feature
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom

object Word2VecCBOWSolver extends Logging {
  // learning rate is updated for every batch of size batchSize
  private val batchSize = 10000

  // power to raise the unigram distribution with
  private val power = 0.75

//  private val EXP_TABLE_SIZE = 1000
  private val EXP_TABLE_SIZE = 10000
  private val MAX_EXP = 6

  case class Vocabulary(
                         totalWordCount: Long,
                         vocabMap: Map[String, Int],
                         unigramTable: Array[Int],
                         samplingTable: Array[Float])

  /**
    * This method implements Word2Vec Continuous Bag Of Words based implementation using
    * negative sampling optimization, using BLAS for vectorizing operations where applicable.
    * The algorithm is parallelized in the same way as the skip-gram based estimation.
    * We divide input data into N equally sized random partitions.
    * We then generate initial weights and broadcast them to the N partitions. This way
    * all the partitions start with the same initial weights. We then run N independent
    * estimations that each estimate a model on a partition. The weights learned
    * from each of the N models are averaged and rebroadcast the weights.
    * This process is repeated `maxIter` number of times.
    *
    * @param input A RDD of strings. Each string would be considered a sentence.
    * @return Estimated word2vec model
    */
  def fit[S <: Iterable[String]](
//                                spark:SparkSession, //TODO
                                  word2Vec: Word2Vec,
                                  skipGramMode: Boolean,
                                  input: RDD[S]): feature.Word2VecModel = {

    val negativeSamples = word2Vec.getNegativeSamples //负采样的词数目
    val sample = word2Vec.getSample //针对高频词的衰减系数
    /**
      * totalWordCount   Long类型的实数
      * vocabMap         Map[String, Int]长度为所有词的大小
      * unigramTable     Array[Int]   长度为设置的采样table的长度
      * samplingTable    Array[Float] 所有词的长度大小,采样表
      *
      */
    val Vocabulary(totalWordCount, vocabMap, uniTable, sampleTable) =
      generateVocab(input, word2Vec.getMinCount, sample, word2Vec.getUnigramTableSize)
    val vocabSize = vocabMap.size

    //TODO 确认负采样的词的数目要小于整个词汇的大小
    assert(negativeSamples < vocabSize, s"Vocab size ($vocabSize) cannot be smaller" +
      s" than negative samples($negativeSamples)")

    val seed = word2Vec.getSeed
    val initRandom = new XORShiftRandom(seed)

    val vectorSize = word2Vec.getVectorSize
    val syn0Global = Array.fill(vocabSize * vectorSize)(initRandom.nextFloat - 0.5f)  //随机初始化的向量作为输入
    val syn1Global = Array.fill(vocabSize * vectorSize)(0.0f)

    val sc = input.context

    //以下是广播的数据量
    val vocabMapBroadcast = sc.broadcast(vocabMap)
    val unigramTableBroadcast = sc.broadcast(uniTable)
    val sampleTableBroadcast = sc.broadcast(sampleTable)
    val expTable = sc.broadcast(createExpTable())

    val windowSize = word2Vec.getWindowSize  //滑动窗口大小
    val maxSentenceLength = word2Vec.getMaxSentenceLength //最大的句子长度
    val numPartitions = word2Vec.getNumPartitions //运行时的分区数目

/*    import spark.implicits._
    val xxx = input.map(x => x.toSeq).toDF("seq")

    sc.parallelize(vocabMap.map(x => (x._1,x._2)).toArray,100).toDF("skuid","index")*/


    //就是把每个句子里的的word转换为index Int格式
    val digitSentences = input.flatMap { sentence =>
      val wordIndexes = sentence.flatMap(vocabMapBroadcast.value.get) // 针对每个句子里面的词,得到该词的index,在flatMap里面应用这个函数[vocabMapBroadcast.value.get],得到每个
      //grouped按照maxSentenceLength分组,把wordIndexes按照最大长度分成几个部分,每个部分的长度不超过 maxSentenceLength
      wordIndexes.grouped(maxSentenceLength).map(_.toArray)
    }.repartition(numPartitions)  //指定分区数目
      .cache() //TODO cache 可以修改存储方式,减少内存
    digitSentences.sparkContext.setCheckpointDir("hdfs://ns4/user/dd_edw/tmp.db/item_relationship/item_embedding/graph_embedding_rdwalk")
    digitSentences.checkpoint()
    digitSentences.count //TODO action
    vocabMapBroadcast.destroy()  //TODO 用完了直接销毁

    val learningRate = word2Vec.getStepSize  //学习率设置的是0.025D

    val wordsPerPartition = totalWordCount / numPartitions   //每个partitions的数据量(以总的词数目不去重为准)

    logInfo(s"VocabSize: ${vocabMap.size}, TotalWordCount: $totalWordCount")

    val maxIter = word2Vec.getMaxIter
    for {iteration <- 1 to maxIter} {   //迭代次数,每次迭代里面有若干个分组batch运行,注意是在每个partition里面的。
      logInfo(s"Starting iteration: $iteration")
      val iterationStartTime = System.nanoTime()

      val syn0bc = sc.broadcast(syn0Global)   //广播词参数
      val syn1bc = sc.broadcast(syn1Global)

      val partialFits = digitSentences.mapPartitionsWithIndex { case (i_, iter) =>
        logInfo(s"Iteration: $iteration, Partition: $i_")
        val random = new XORShiftRandom(seed ^ ((i_ + 1) << 16) ^ ((-iteration - 1) << 8))
        val contextWordPairs = iter.flatMap { s => //iter为一个分区里的所有句子迭代器  s为遍历一个分区里的句子
          val doSample = sample > Double.MinPositiveValue   //是否有采样系数 boolean类型
          /**
            *得到的是Iterator[(Seq[Int], Int)],其中Seq[Int]是后面Int的上下窗口词集合,后者Int是中心词index
            */
          generateContextWordPairs(
            s,
            windowSize,
            doSample,
            sampleTableBroadcast.value,
            random)
        }
        //把所有的中心词对应的窗口词集合分组批次,按照batchSize大小  batchSize大小默认是10000
        val groupedBatches = contextWordPairs.grouped(batchSize)
        //负采样标签,negativeSamples负采样词的个数
        val negLabels = 1.0f +: Array.fill(negativeSamples)(0.0f)
        val syn0 = syn0bc.value
        val syn1 = syn1bc.value
        val unigramTable = unigramTableBroadcast.value //长度大小为负采样表的长度,负采样的table,目前设置的长度为2千万

        // initialize intermediate arrays
        val contextVec = new Array[Float](vectorSize)
        val l2Vectors = new Array[Float](vectorSize * (negativeSamples + 1))
        val gb = new Array[Float](negativeSamples + 1)
        val neu1e = new Array[Float](vectorSize)
        val wordIndices = new Array[Int](negativeSamples + 1)

        val time = System.nanoTime
        var batchTime = System.nanoTime
        var idx = -1L
        for (batch <- groupedBatches) { // 一个batch就是Seq[(Seq[Int], Int)]集合
          idx = idx + 1 //每迭代一个batch就会idx增加1

          val wordRatio = //会随着idx和iteration的增大而增大,越往后迭代得到的wordRatio越小
            idx.toFloat * batchSize / (maxIter * (wordsPerPartition.toFloat + 1)) +
              ((iteration - 1).toFloat / maxIter)

            //学习率会随着wordRatio的增大而减小,但是不会小于设置的learningRate * 0.0001,越往后迭代学习率会越小
          val alpha = math.max(learningRate * 0.0001, learningRate * (1 - wordRatio)).toFloat

          if(idx % 10 == 0 && idx > 0) { //TODO 每迭代10个batch会做一次汇总,对当前运行过得batch的时间进行统计。
            logInfo(s"Partition: $i_, wordRatio = $wordRatio, alpha = $alpha") //打印各个分区index,学习率等数据
            val wordCount = batchSize * idx   //本分区总共已经迭代的中心词个数
            val timeTaken = (System.nanoTime - time) / 1e6  // 对每个分区定时器,计算到这步所用的时间
            val batchWordCount = 10 * batchSize  //
            val currentBatchTime = (System.nanoTime - batchTime) / 1e6  // 对每个分区定时器,计算到这步所用的时间
            batchTime = System.nanoTime
            logDebug(s"Partition: $i_, Batch time: $currentBatchTime ms, batch speed: " +
              s"${batchWordCount / currentBatchTime * 1000} words/s")
            logDebug(s"Partition: $i_, Cumulative time: $timeTaken ms, cumulative speed: " +
              s"${wordCount / timeTaken * 1000} words/s")
          }

          val errors = for ((ids, word) <- batch) yield {  //遍历每个batch里面的内容
            val contexts = if (skipGramMode) { //如果是sg-ns模型
              ids.map(i => Seq(i))  //把每个上下窗口的词变成Seq集合
            } else {
              Seq(ids)  //Seq[Seq[int]]
            }

            val errs = for (contextIds <- contexts) yield {
              // initialize vectors to 0
              zeroVector(contextVec)
              zeroVector(l2Vectors)
              zeroVector(gb)
              zeroVector(neu1e)

              val scale = 1.0f / contextIds.length  //上下SKU个数 如果是sg-ns的话,一个词就是一个Seq,如果是CNOW-ns的话上下SKU集合是一个Seq

              // feed forward   前馈
              contextIds.foreach { c =>
                //blas.saxpy函数解释:saxpy(int n, float sa, float[] sx, int _sx_offset, int incx, float[] sy, int _sy_offset, int incy),
                // 结果是:sy= sa*sx+sy,并且sx[_sx_offset , _sx_offset + incx*n],sy[ _sy_offset , _sy_offset + incy*n],
                // 本语句的意思是:contextVec = scale*syn0 + contextVec 其中的位移范围是一个vectorSize的长度
                blas.saxpy(vectorSize, scale, syn0, c * vectorSize, 1, contextVec, 0, 1)
              }

              //word是指当前中心词,针对每个中心词采样negativeSamples个词
              generateNegativeSamples(random, word, unigramTable, negativeSamples, wordIndices)

              Iterator.range(0, wordIndices.length).foreach { i =>
                // copy(src: AnyRef, srcPos: Int, dest: AnyRef, destPos: Int, length: Int)
                // 把syn1复制到l2Vectors  l2Vectors是负采样词向量表
                Array.copy(syn1, vectorSize * wordIndices(i), l2Vectors, vectorSize * i, vectorSize)
              }

              // propagating hidden to output in batch  传播隐藏层到output层
              val rows = negativeSamples + 1
              val cols = vectorSize
              //sgemv(trans: String, m: Int, n: Int, alpha: Float, a: Array[Float], _a_offset: Int, lda: Int, x: Array[Float],
              // _x_offset: Int, incx: Int, beta: Float, y: Array[Float], _y_offset: Int, incy: Int)
              // y := alpha*A*x + beta*y,   or   y := alpha*A'*x + beta*y  |||||||====|||||||  gb = 1.0f * l2Vectors * contextVec + 0.0f * gb
              blas.sgemv("T", cols, rows, 1.0f, l2Vectors, 0, cols, contextVec, 0, 1, 0.0f, gb, 0, 1)

              Iterator.range(0, negativeSamples + 1).foreach { i =>
                if (gb(i) > -MAX_EXP && gb(i) < MAX_EXP) {

                  val ind = ((gb(i) + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt  //TODO 修改sigmod函数的调用采用查表方式
                  val v = expTable.value(ind)

//                  val v = 1.0f / (1 + math.exp(-gb(i)).toFloat)  //sigmod计算,其实可以进行查表运算
                  // computing error gradient
                  val err = (negLabels(i) - v) * alpha  //梯度
                  // update hidden -> output layer, syn1
                  // syn1 = err * contextVec + syn1
                  blas.saxpy(vectorSize, err, contextVec, 0, 1, syn1, wordIndices(i) * vectorSize, 1)
                  // update for word vectors
                  // neu1e = err * l2Vectors + neu1e
                  blas.saxpy(vectorSize, err, l2Vectors, i * vectorSize, 1, neu1e, 0, 1)
                  gb.update(i, err)
                } else {
                  gb.update(i, 0.0f)
                }
              }

              // update input -> hidden layer, syn0
              contextIds.foreach { i =>
                // syn0 = 1.0f * neu1e + syn0
                blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, i * vectorSize, 1)
              }
              gb.map(math.abs).sum / alpha
            }
            errs.sum  //每个中心词迭代后的误差
          }
          logInfo(s"Partition: $i_, Average Batch Error = ${errors.sum / batchSize}")
        }
        Iterator.tabulate(vocabSize) { index =>
          (index, syn0.slice(index * vectorSize, (index + 1) * vectorSize))
        } ++ Iterator.tabulate(vocabSize) { index =>
          (vocabSize + index, syn1.slice(index * vectorSize, (index + 1) * vectorSize))
        }
      }

      val aggedMatrices = partialFits.reduceByKey { case (v1, v2) =>
        blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
        v1

      /*val aggedMatrices =  partialFits.mapPartitions { iter =>
        iter.map { case (id, vec) =>
          (id, (vec, 1))
        }
      }.reduceByKey { case ((v1, count1), (v2, count2)) =>
        blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
        (v1, count1 + count2)
      }.map { case (id, (vec, count)) =>
        blas.sscal(vectorSize, 1.0f / count, vec, 1)
        (id, vec)*/

      }.collect()

      val norm = 1.0f / numPartitions
      aggedMatrices.foreach {case (index, v) =>
        blas.sscal(v.length, norm, v, 0, 1)
        if (index < vocabSize) {
          Array.copy(v, 0, syn0Global, index * vectorSize, vectorSize)
        } else {
          Array.copy(v, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
        }
      }

      syn0bc.destroy(false)
      syn1bc.destroy(false)
      val timePerIteration = (System.nanoTime() - iterationStartTime) / 1e6
      logInfo(s"Total time taken per iteration: ${timePerIteration} ms")
    }
    digitSentences.unpersist()
//    vocabMapBroadcast.destroy()
    unigramTableBroadcast.destroy()
    sampleTableBroadcast.destroy()

    new feature.Word2VecModel(vocabMap, syn0Global)
  }

  /**
    * Similar to InitUnigramTable in the original code.   跟源码一样的操作
    */
  private def generateUnigramTable(normalizedWeights: Array[Double], tableSize: Int): Array[Int] = {
    val table = new Array[Int](tableSize)
    var index = 0
    var wordId = 0
    while (index < table.length) { //遍历table
      table.update(index, wordId)
      //[index.toFloat / table.length]这个值最大值是1,normalizedWeights数组最大值也是1,强制把normalizedWeights分成tableSize个区间,
      // 按照table索引的刻度进行分割,最终得到的table里面的每个元素是相邻的是wordID正是table一个刻度所包含的word
      if (index.toFloat / table.length >= normalizedWeights(wordId)) {
        wordId = math.min(normalizedWeights.length - 1, wordId + 1)
      }
      index += 1
    }
    table
  }

  private def createExpTable(): Array[Float] = {
    val expTable = new Array[Float](EXP_TABLE_SIZE)
    var i = 0
    while (i < EXP_TABLE_SIZE) {
      val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
      expTable(i) = (tmp / (tmp + 1.0)).toFloat
      i += 1
    }
    expTable
  }

  /**
    *
    * @param input
    * @param minCount  最小的频次限制
    * @param sample   sample的系数设置
    * @param unigramTableSize  采样的表格大小
    * @tparam S
    * @return
    */
  private def generateVocab[S <: Iterable[String]](
                                                    input: RDD[S],
                                                    minCount: Int,
                                                    sample: Double,
                                                    unigramTableSize: Int): Vocabulary = {
    val sc = input.context

    val words = input.flatMap(x => x)

    //按照词的频次进行排序,并且zip上index
    val sortedWordCounts = words.map(w => (w, 1L))
      .reduceByKey(_ + _)
      .filter{case (w, c) => c >= minCount}
      .collect()
      .sortWith{case ((w1, c1), (w2, c2)) => c1 > c2}
      .zipWithIndex

    val totalWordCount = sortedWordCounts.map(_._1._2).sum   //所有词的总和

    //每个词的对应的index索引
    val vocabMap = sortedWordCounts.map{case ((w, c), i) =>
      w -> i
    }.toMap

    //所有词大小的 抽样表
    val samplingTable = new Array[Float](vocabMap.size)

    if (sample > Double.MinPositiveValue) {  // sample 大于double最小值
      sortedWordCounts.foreach { case ((w, c), i) =>
        val samplingRatio = sample * totalWordCount / c  //采样的概率
        samplingTable.update(i, (math.sqrt(samplingRatio) + samplingRatio).toFloat)
      }
    }

    val weights = sortedWordCounts.map{ case((_, x), _) => scala.math.pow(x, power)} //对每个词的频次进行 f^0.75次方
    val totalWeight = weights.sum   //所有的权重和

    //scanLeft:扫描,即对某个集合的所有元素做fold操作,但是会把产生的所有中间结果放置于一个集合中保存 ,跟foldLeft还是有区别的,foldLeft不存储中间结果
    //TODO normalizedCumWeights数组的长度为所有词的个数大小
    val normalizedCumWeights = weights.scanLeft(0.0)(_ + _).tail.map(x => x / totalWeight) //数组的tail操作,除了头全是尾部
    //Unigram table size. The unigram table is used to generate negative samples.    本程序设置的2千万
    val unigramTable = generateUnigramTable(normalizedCumWeights, unigramTableSize)

    /**
      * totalWordCount   Long类型的实数
      * vocabMap         Map[String, Int]长度为所有词的大小
      * unigramTable     Array[Int]   长度为设置的采样table的长度
      * samplingTable    Array[Float] 所有词的长度大小,采样表
      *
      */
    Vocabulary(totalWordCount, vocabMap, unigramTable, samplingTable)
  }

  private def zeroVector(v: Array[Float]): Unit = {
    var i = 0
    while(i < v.length) {
      v.update(i, 0.0f)
      i+= 1
    }
  }

  /**
    *
    * @param sentence
    * @param window
    * @param doSample
    * @param samplingTable
    * @param random
    * @return   生成中心词的上下文 词对
    */
  private def generateContextWordPairs(
                                        sentence: Array[Int],
                                        window: Int,
                                        doSample: Boolean,
                                        samplingTable: Array[Float],
                                        random: XORShiftRandom): Iterator[(Seq[Int], Int)] = {
    val reducedSentence = if (doSample) {
      sentence.filter(i => samplingTable(i) > random.nextFloat)  //每个句子里,随机选取一些词
    } else {
      sentence
    }
    val sentenceLength = reducedSentence.length

    Iterator.range(0, sentenceLength)//该句子的长度
      .map { i =>
      val b = window - random.nextInt(window) // (window - a) in original code
    // pick b words around the current word index
    val start = math.max(0, i - b) // c in original code, floor ar 0
    val end = math.min(sentenceLength, i + b + 1) // cap at sentence length
    // make sure current word is not a part of the context
    val contextIds = Iterator.range(start, end).filter(_ != i).map(reducedSentence(_)) //得到start,end范围内的词的索引
      val word = reducedSentence(i)
      (contextIds.toSeq, word)
    }
  }

  /**
    *
    * @param random
    * @param word
    * @param unigramTable
    * @param numSamples
    * @param arr  最终返回的是arr,里面第一个是word中心词,后面依次是numSamples个采样词
    */
  // This essentially helps translate from uniform distribution to a distribution
  // resembling uni-gram frequency distribution.
  private def generateNegativeSamples(
                                       random: XORShiftRandom,
                                       word: Int,
                                       unigramTable: Array[Int],
                                       numSamples: Int,
                                       arr: Array[Int]): Unit = {
    assert(numSamples + 1 == arr.length,
      s"Input array should be large enough to hold ${numSamples} negative samples")
    arr.update(0, word)  //arr的第一个元素是word(中心词Word)
    var i = 1
    while (i <= numSamples) { //迭代随机选取样本
      val negSample = unigramTable(random.nextInt(unigramTable.length))
      if(negSample != word) {
        arr.update(i, negSample)
        i += 1
      }
    }
  }
}

你可能感兴趣的:(数据挖掘)