Spark Word2Vec Example

This post demonstrates how to train, save, and load a Word2Vec model with the Word2Vec class in Spark's ml package.
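
The example only depends on spark-mllib (which pulls in spark-sql transitively). A minimal build.sbt sketch, assuming Spark 3.x (adjust the version to match your cluster):

libraryDependencies += "org.apache.spark" %% "spark-mllib" % "3.5.0"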

import org.apache.spark.ml.feature.{Word2Vec, Word2VecModel}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

/**
  * Train word and sentence embeddings from tokenized sentences.
  */
object TextEmbedding {
  val embeddingSize = 3

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("TextEmbedding")
      .master("local[2]")
      .getOrCreate()

    val df = loadText(spark)
    val model = trainByWord2Vec(spark, df)
    saveModel(spark, model, args(0))
    saveTextEmb(spark, model, df, args(1))
    saveWordEmb(spark, model, args(2))

    loadModel(spark, args(0))

    spark.stop()
  }

  // Sample data: a tiny DataFrame of tokenized sentences
  def loadText(spark: SparkSession): DataFrame = {
    val df = spark.createDataFrame(Seq(
      (0, Array("Hi", "I", "heard", "about", "Spark")),
      (1, Array("I", "wish", "Java", "could", "use", "case", "classes")),
      (2, Array("Logistic", "regression", "models", "are", "neat"))
    )).toDF("id", "words")

    println("df count:" + df.count())
    df.show(10, false)
    df
  }

  // Training: fit a Word2Vec model on the tokenized sentences
  def trainByWord2Vec(spark: SparkSession, df: DataFrame): Word2VecModel = {
    val word2Vec = new Word2Vec()
      .setInputCol("words")
      .setOutputCol("result")
      .setVectorSize(embeddingSize) // dimension of the learned vectors
      .setMinCount(0)               // keep every word: the toy corpus is too small for the default minCount of 5
      .setWindowSize(3)
      .setMaxIter(10)

    val model = word2Vec.fit(df)

    model.getVectors.show(100, false)
    //    +----------+------------------------------------------------------------------+
    //    |word      |vector                                                            |
    //    +----------+------------------------------------------------------------------+
    //    |heard     |[-0.053989291191101074,0.14687322080135345,-0.0022512583527714014]|
    //    +----------+------------------------------------------------------------------+

    model

  }

  // Save the fitted Word2Vec model
  def saveModel(spark: SparkSession, model: Word2VecModel, path: String): Unit = {
    println("saving model: " + path)
    // ML writers have no mode(); overwrite() replaces any existing model at the path
    model.write.overwrite().save(path)
  }

  // Save the sentence (text) embeddings produced by model.transform
  def saveTextEmb(spark: SparkSession, model: Word2VecModel, df: DataFrame, path: String): Unit = {
    println(s"saving $path")
    val result = model.transform(df)
    result.printSchema()
    result.show(false)

    result.select("id", "result")
      .repartition(5)
      .write
      .mode(SaveMode.Overwrite)
      .parquet(path)

  }

  // Save the word embeddings (one vector per vocabulary word)
  def saveWordEmb(spark: SparkSession, model: Word2VecModel, path: String): Unit = {
    println(s"saving $path")
    val result = model.getVectors
    result.show(false)
    result.repartition(5).write.mode(SaveMode.Overwrite).parquet(path)
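    // The saved embeddings can be read back with a plain parquet read, e.g.
    //   spark.read.parquet(path).show(false)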
  }

  // Load the model back and inspect its vectors
  def loadModel(spark: SparkSession, path: String): Word2VecModel = {
    val model = Word2VecModel.load(path)

    model.getVectors.show(false)
    //    +----------+------------------------------------------------------------------+
    //    |word      |vector                                                            |
    //    +----------+------------------------------------------------------------------+
    //    |heard     |[-0.053989291191101074,0.14687322080135345,-0.0022512583527714014]|
    //    |are       |[-0.16293057799339294,-0.14514029026031494,0.1139335036277771]    |
    //    |neat      |[-0.0406828410923481,0.028049567714333534,-0.16289857029914856]   |
    //    |classes   |[-0.1490514725446701,-0.04974571615457535,0.03320947289466858]    |
    //    |I         |[-0.019095497205853462,-0.131216898560524,0.14303986728191376]    |
    //    |regression|[0.16541987657546997,0.06469681113958359,0.09233078360557556]     |
    //    |Logistic  |[0.036407098174095154,0.05800342187285423,-0.021965932101011276]  |
    //    |Spark     |[-0.1267719864845276,0.09859133511781693,-0.10378564894199371]    |
    //    |could     |[0.15352481603622437,0.06008218228816986,0.07726015895605087]     |
    //    |use       |[0.08318991959095001,0.002120430115610361,-0.07926633954048157]   |
    //    |Hi        |[-0.05663909390568733,0.009638422168791294,-0.033786069601774216] |
    //    |models    |[0.11912573128938675,0.1333899050951004,0.1441687047481537]       |
    //    |case      |[0.14080166816711426,0.08094961196184158,0.1596144139766693]      |
    //    |about     |[0.11579915136098862,0.10381520539522171,-0.06980287283658981]    |
    //    |Java      |[0.12235434353351593,-0.03189820423722267,-0.1423865109682083]    |
    //    |wish      |[0.14934538304805756,-0.11263544857501984,-0.03990427032113075]   |
    //    +----------+------------------------------------------------------------------+

    println("model.getVectors.printSchema()")
    model.getVectors.printSchema()

    // Collect the vocabulary into a driver-side Map from word to vector
    val embMap = model.getVectors
      .collect()
      .map(row => {
        val key = row.getAs[String]("word")
        val value = row.getAs[Vector]("vector")
        (key, value)
      }).toMap
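
    // Example driver-side lookup in the collected map
    embMap.get("Spark").foreach(vec => println("Spark => " + vec.toArray.mkString(",")))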

    // Similarity example: the two words closest to "I"
    model.findSynonyms("I", 2).show(false)
    //    +-------+-------------------+
    //    |word   |similarity         |
    //    +-------+-------------------+
    //    |are    |0.800910234451294  |
    //    |classes|0.45088085532188416|
    //    +-------+-------------------+
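
    // findSynonyms ranks vocabulary words by cosine similarity to the query word.
    // findSynonymsArray returns the same ranking as a local Array instead of a DataFrame:
    val synonyms = model.findSynonymsArray("I", 2)
    synonyms.foreach { case (word, sim) => println(s"$word: $sim") }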

    model

  }

}
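
main expects three output paths: the model directory, the sentence-embedding output, and the word-embedding output. A minimal local invocation sketch, with hypothetical /tmp paths:

TextEmbedding.main(Array(
  "/tmp/w2v_model", // hypothetical model save path
  "/tmp/text_emb",  // hypothetical sentence-embedding parquet path
  "/tmp/word_emb"   // hypothetical word-embedding parquet path
))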
