Spark MLlib特征处理:OneHotEncoder OneHot编码 ---原理及实战

原理

1)使用StringIndexer将String字符串转换成Double类型的索引IndexDouble(按出现频率从高到低编号,出现次数最多的类别索引为0.0)

2)使用OneHotEncoder将索引转换成SparseVector(稀疏向量)

总结:OneHot编码流程 = String -> IndexDouble -> SparseVector(第一步由StringIndexer完成,第二步由OneHotEncoder完成)

代码实战

import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.{SparkContext, SparkConf}

/**
 * Minimal demo of one-hot encoding a string column with Spark ML.
 *
 * The pipeline has two steps:
 *   1. `StringIndexer` maps each distinct string in `category` to a Double index
 *      (`categoryIndex`).
 *   2. `OneHotEncoder` turns that index into a one-hot vector (`categoryVec`).
 */
object OneHotEncoderExample {

  /**
   * Builds a small in-memory DataFrame, index-encodes and one-hot-encodes its
   * `category` column, and prints the resulting columns to stdout.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("OneHotEncoderExample").setMaster("local[8]")
    val sc = new SparkContext(conf)
    try {
      val sqlContext = new SQLContext(sc)

      // Convert a Seq of tuples into a DataFrame.
      // Seq is an ordered sequence; Vector, Range and List are Seq implementations
      // (an Array can be used as one through an implicit conversion).
      val df: DataFrame = sqlContext.createDataFrame(Seq(
        (0, "a"),
        (1, "b"),
        (2, "c"),
        (3, "a"),
        (4, "a"),
        (5, "c")
      )).toDF("id", "category")

      // String => IndexDouble
      val indexer = new StringIndexer().setInputCol("category").setOutputCol("categoryIndex")
      val indexed = indexer.fit(df).transform(df)

      // IndexDouble => SparseVector
      // One-hot encoding here actually produces a sparse vector.
      // Spark source: "The last category is not included by default".
      // This differs from Python scikit-learn's OneHotEncoder, which includes every category.
      val encoder = new OneHotEncoder().setInputCol("categoryIndex").setOutputCol("categoryVec")
        // Keep the last category too, so every category gets its own vector slot.
        .setDropLast(false)

      // transform() appends the one-hot vector column.
      val encoded = encoder.transform(indexed)
      encoded.select("category", "categoryIndex", "categoryVec").show()
    } finally {
      // Always shut the SparkContext down, even if one of the steps above throws.
      sc.stop()
    }
  }

}
// 输出
// +--------+-------------+-------------+
// |category|categoryIndex|  categoryVec|
// +--------+-------------+-------------+
// |       a|          0.0|(3,[0],[1.0])|
// |       b|          2.0|(3,[2],[1.0])|
// |       c|          1.0|(3,[1],[1.0])|
// |       a|          0.0|(3,[0],[1.0])|
// |       a|          0.0|(3,[0],[1.0])|
// |       c|          1.0|(3,[1],[1.0])|
// +--------+-------------+-------------+

你可能感兴趣的:(机器学习)