scala做embedding的average操作

scala做embedding的average操作

使用 breeze.linalg 对 embedding 向量进行处理

breeze.linalg 库支持对矩阵和向量进行多种运算,包括普通的加减乘除、点乘、叉乘等。

import breeze.linalg.DenseVector
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import org.apache.spark.storage.StorageLevel

/**
  * Averages embedding vectors in Scala: for every row of a history table, looks up the
  * embedding of each listed id and emits the element-wise mean of those embeddings.
  */
object EmbeddingAvg {
  /** Fallback embedding dimension, used only when the lookup table is empty. */
  val embeddingSize = 3

  /**
    * Computes the element-wise average embedding for every row of `history`.
    *
    * Ids that are not present in `emb` contribute a zero vector and still count towards
    * the divisor. A row whose `list` is null or splits into no ids (for example `","`)
    * yields an all-zero embedding instead of failing the job.
    *
    * Side effects: prints the embedding dimension and the row count, shows the result on
    * stdout, and persists the result with `MEMORY_AND_DISK_SER`.
    *
    * @param spark   active session; supplies the encoders and the broadcast facility
    * @param emb     lookup table read through a String column `id` and an ML `Vector`
    *                column `vector`; it is collected to the driver in full
    * @param history rows read through a String column `key` and a String column `list`
    *                holding comma-separated ids
    * @return a persisted DataFrame with columns `key` and `emb`, where `emb` is the averaged
    *         vector rendered as a comma-separated string
    */
  def avgAction(spark: SparkSession, emb: DataFrame, history: DataFrame): DataFrame = {
    import spark.implicits._

    val embMap = emb.collect()
      .map(row => (row.getAs[String]("id"), row.getAs[Vector]("vector")))
      .toMap

    // Take the dimension from the data so the zero padding for unknown ids always has the
    // same length as the real embeddings; fall back to the constant for an empty table.
    val defaultEmbSize = embMap.headOption.map(_._2.size).getOrElse(embeddingSize)
    println("defaultEmbSize:" + defaultEmbSize)

    // Broadcast the lookup table once per executor rather than shipping it with every task.
    val embBroadcast = spark.sparkContext.broadcast(embMap)

    val result = history.map(row => {
      val key = row.getAs[String]("key")
      val lookup = embBroadcast.value

      // String.split drops trailing empty strings, so "," produces an empty array;
      // a null column value is treated the same way.
      val ids = Option(row.getAs[String]("list"))
        .map(_.split(","))
        .getOrElse(Array.empty[String])

      val vectors = ids.map { id =>
        new DenseVector[Double](lookup.getOrElse(id, Vectors.zeros(defaultEmbSize)).toArray)
      }

      // Average: sum all vectors, then scale by 1 / count.
      val avg = vectors
        .reduceOption((a: DenseVector[Double], b: DenseVector[Double]) => a + b)
        .map(sum => sum *:* (1.0 / vectors.length))
        .getOrElse(DenseVector.zeros[Double](defaultEmbSize))

      (key, avg.toArray.mkString(","))
    }).toDF("key", "emb")
      .persist(StorageLevel.MEMORY_AND_DISK_SER)

    println("result count:"  + result.count())
    result.show(false)
    result
  }

}

breeze.linalg 库的更多用法可参考:https://www.cnblogs.com/itboys/p/10594039.html

你可能感兴趣的:(Spark,scala,机器学习,spark,scala,embedding)