scala--标签和索引的转化:StringIndexer- IndexToString-VectorIndexer

来源:http://mocom.xmu.edu.cn/article/show/587f11deaa2c3f280956e7ba/0/1

版权声明: 本文发自http://mocom.xmu.edu.cn,为 赖永炫 老师的个人博文,文章仅代表个人观点。无需授权即可转载,转载时请务必注明作者。



Spark的机器学习处理过程中,经常需要把标签数据(一般是字符串)转化成整数索引,而在计算结束又需要把整数索引还原为标签。这就涉及到几个转换器:StringIndexer、 IndexToString,OneHotEncoder,以及针对类别特征的索引VectorIndexer。

StringIndexer

​ StringIndexer是指把一组字符型标签编码成一组标签索引,索引的范围为0到标签数量,索引构建的顺序为标签的频率,优先编码频率较大的标签,所以出现频率最高的标签为0号。如果输入的是数值型的,我们会把它转化成字符型,然后再对其进行编码。在pipeline组件,比如Estimator和Transformer中,想要用到字符串索引的标签的话,我们一般需要通过setInputCol来设置输入列。另外,有的时候我们通过一个数据集构建了一个StringIndexer,然后准备把它应用到另一个数据集上的时候,会遇到新数据集中有一些没有在前一个数据集中出现的标签,这时候一般有两种策略来处理:第一种是抛出一个异常(默认情况下),第二种是通过掉用 setHandleInvalid("skip")来彻底忽略包含这类标签的行。

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.feature.StringIndexer

scala> val sqlContext = new SQLContext(sc)
sqlContext: org.apache.spark.sql.SQLContext = org.apache.spark.sql.SQLContext@2869d920

scala> import sqlContext.implicits._
import sqlContext.implicits._

scala> val df1 = sqlContext.createDataFrame(
     |       Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
     |     ).toDF("id", "category")
df1: org.apache.spark.sql.DataFrame = [id: int, category: string]

scala> val indexer = new StringIndexer().
     |       setInputCol("category").
     |       setOutputCol("categoryIndex")
indexer: org.apache.spark.ml.feature.StringIndexer = strIdx_95a0a5afdb8b

scala> val indexed1 = indexer.fit(df1).transform(df1)
indexed1: org.apache.spark.sql.DataFrame = [id: int, category: string, categoryIndex: double]

scala> indexed1.show()
+---+--------+-------------+
| id|category|categoryIndex|
+---+--------+-------------+
|  0|       a|          0.0|
|  1|       b|          2.0|
|  2|       c|          1.0|
|  3|       a|          0.0|
|  4|       a|          0.0|
|  5|       c|          1.0|
+---+--------+-------------+

scala> val df2 = sqlContext.createDataFrame(
     |       Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "d"))
     |     ).toDF("id", "category")
df2: org.apache.spark.sql.DataFrame = [id: int, category: string]

scala> val indexed2 = indexer.fit(df1).setHandleInvalid("skip").transform(df2)
indexed2: org.apache.spark.sql.DataFrame = [id: int, category: string, categoryIndex: double]

scala> indexed2.show()
+---+--------+-------------+
| id|category|categoryIndex|
+---+--------+-------------+
|  0|       a|          0.0|
|  1|       b|          2.0|
|  2|       c|          1.0|
|  3|       a|          0.0|
|  4|       a|          0.0|
+---+--------+-------------+

scala> val indexed3 = indexer.fit(df1)transform(df2)
indexed3: org.apache.spark.sql.DataFrame = [id: int, category: string, categoryIndex: double]

scala> indexed3.show()
org.apache.spark.SparkException: Unseen label: d.

​ 在上例当中,我们首先构建了1个dataframe,然后设置了StringIndexer的输入列和输出列的名字。通过indexed1.show(),我们可以看到,StringIndexer依次按照出现频率的高低,把字符标签进行了排序,即出现最多的“a”被编号成0,“c”为1,出现最少的“b”为0。接下来,我们构建了一个新的dataframe,这个dataframe中有一个再上一个dataframe中未曾出现的标签“d”,然后我们通过设置setHandleInvalid("skip")来忽略标签“d”的行,结果通过indexed2.show()可以看到,含有标签“d”的行并没有出现。如果,我们没有设置的话,则会抛出异常,报出“Unseen label: d”的错误。

IndexToString

​ 对称的,IndexToString的作用是把标签索引的一列重新映射回原有的字符型标签。一般都是和StringIndexer配合,先用StringIndexer转化成标签索引,进行模型训练,然后在预测标签的时候再把标签索引转化成原有的字符标签。当然,也允许你使用自己提供的标签。

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.feature.{StringIndexer, IndexToString}

scala> val sqlContext = new SQLContext(sc)
sqlContext: org.apache.spark.sql.SQLContext = org.apache.spark.sql.SQLContext@2869d920

scala> import sqlContext.implicits._
import sqlContext.implicits._

scala> val df = sqlContext.createDataFrame(Seq(
     |       (0, "a"),
     |       (1, "b"),
     |       (2, "c"),
     |       (3, "a"),
     |       (4, "a"),
     |       (5, "c")
     |     )).toDF("id", "category")
df: org.apache.spark.sql.DataFrame = [id: int, category: string]

scala> val indexer = new StringIndexer().
     |       setInputCol("category").
     |       setOutputCol("categoryIndex").
     |       fit(df)
indexer: org.apache.spark.ml.feature.StringIndexerModel = strIdx_00fde0fe64d0

scala> val indexed = indexer.transform(df)
indexed: org.apache.spark.sql.DataFrame = [id: int, category: string, categoryIndex: double]

scala> val converter = new IndexToString().
     |       setInputCol("categoryIndex").
     |       setOutputCol("originalCategory")
converter: org.apache.spark.ml.feature.IndexToString = idxToStr_b95208a0e7ac

scala> val converted = converter.transform(indexed)
converted: org.apache.spark.sql.DataFrame = [id: int, category: string, categoryIndex: double, originalCategory: string]

scala> converted.select("id", "originalCategory").show()
+---+----------------+
| id|originalCategory|
+---+----------------+
|  0|               a|
|  1|               b|
|  2|               c|
|  3|               a|
|  4|               a|
|  5|               c|
+---+----------------+

​ 在上例中,我们首先用StringIndexer读取数据集中的“category”列,把字符型标签转化成标签索引,然后输出到“categoryIndex”列上。然后再用IndexToString读取“categoryIndex”上的标签索引,获得原有数据集的字符型标签,然后再输出到“originalCategory”列上。最后,通过输出“originalCategory”列,可以看到数据集中原有的字符标签。

OneHotEncoder

​ 独热编码是指把一列标签索引映射成一列二进制数组,且最多的时候只有一位有效。这种编码适合一些期望类别特征为连续特征的算法,比如说逻辑斯蒂回归。

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

scala> val sqlContext = new SQLContext(sc)
sqlContext: org.apache.spark.sql.SQLContext = org.apache.spark.sql.SQLContext@2869d920

scala> import sqlContext.implicits._
import sqlContext.implicits._

scala> val df = sqlContext.createDataFrame(Seq(
     |       (0, "a"),
     |       (1, "b"),
     |       (2, "c"),
     |       (3, "a"),
     |       (4, "a"),
     |       (5, "c"),
     |       (6, "d"),
     |       (7, "d"),
     |       (8, "d"),
     |       (9, "d"),
     |       (10, "e"),
     |       (11, "e"),
     |       (12, "e"),
     |       (13, "e"),
     |       (14, "e")
     |     )).toDF("id", "category")
df: org.apache.spark.sql.DataFrame = [id: int, category: string]

scala> val indexer = new StringIndexer().
     |       setInputCol("category").
     |       setOutputCol("categoryIndex").
     |       fit(df)
indexer: org.apache.spark.ml.feature.StringIndexerModel = strIdx_b315cf21d22d

scala> val indexed = indexer.transform(df)
indexed: org.apache.spark.sql.DataFrame = [id: int, category: string, categoryIndex: double]

scala> val encoder = new OneHotEncoder().
     |       setInputCol("categoryIndex").
     |       setOutputCol("categoryVec")
encoder: org.apache.spark.ml.feature.OneHotEncoder = oneHot_bbf16821b33a

scala> val encoded = encoder.transform(indexed)
encoded: org.apache.spark.sql.DataFrame = [id: int, category: string, categoryIndex: double, categoryVec: vector]

scala> encoded.show()
+---+--------+-------------+-------------+
| id|category|categoryIndex|  categoryVec|
+---+--------+-------------+-------------+
|  0|       a|          2.0|(4,[2],[1.0])|
|  1|       b|          4.0|    (4,[],[])|
|  2|       c|          3.0|(4,[3],[1.0])|
|  3|       a|          2.0|(4,[2],[1.0])|
|  4|       a|          2.0|(4,[2],[1.0])|
|  5|       c|          3.0|(4,[3],[1.0])|
|  6|       d|          1.0|(4,[1],[1.0])|
|  7|       d|          1.0|(4,[1],[1.0])|
|  8|       d|          1.0|(4,[1],[1.0])|
|  9|       d|          1.0|(4,[1],[1.0])|
| 10|       e|          0.0|(4,[0],[1.0])|
| 11|       e|          0.0|(4,[0],[1.0])|
| 12|       e|          0.0|(4,[0],[1.0])|
| 13|       e|          0.0|(4,[0],[1.0])|
| 14|       e|          0.0|(4,[0],[1.0])|
+---+--------+-------------+-------------+

​ 在上例中,我们构建了一个dataframe,包含“a”,“b”,“c”,“d”,“e” 五个标签,通过调用OneHotEncoder,我们发现出现频率最高的标签“e”被编码成第0位为1,即第0位有效,出现频率第二高的标签“d”被编码成第1位有效,依次类推,“a”和“c”也被相继编码,出现频率最小的标签“b”被编码成全0。

VectorIndexer

    VectorIndexer解决向量数据集中的类别特征索引。它可以自动识别哪些特征是类别型的,并且将原始值转换为类别索引。它的处理流程如下:

​ 1.获得一个向量类型的输入以及maxCategories参数。

​ 2.基于不同特征值的数量来识别哪些特征需要被类别化,其中最多maxCategories个特征需要被类别化。

​ 3.对于每一个类别特征计算0-based(从0开始)类别索引。

​ 4.对类别特征进行索引然后将原始特征值转换为索引。

   索引后的类别特征可以帮助决策树等算法恰当的处理类别型特征,并得到较好结果。

   在下面的例子中,我们读入一个数据集,然后使用VectorIndexer来决定哪些特征需要被作为类别特征,将类别特征转换为他们的索引。
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.mllib.linalg.{Vector, Vectors}

scala> val sqlContext = new SQLContext(sc)
sqlContext: org.apache.spark.sql.SQLContext = org.apache.spark.sql.SQLContext@2869d920

scala> import sqlContext.implicits._
import sqlContext.implicits._

scala> val data = Seq(Vectors.dense(-1.0, 1.0, 1.0),Vectors.dense(-1.0, 3.0, 1.0), Vectors.dense(0.0, 5.0, 1.0))
data: Seq[org.apache.spark.mllib.linalg.Vector] = List([-1.0,1.0,1.0], [-1.0,3.0,1.0], [0.0,5.0,1.0])

scala> val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features")
df: org.apache.spark.sql.DataFrame = [features: vector]

scala> val indexer = new VectorIndexer().
     |       setInputCol("features").
     |       setOutputCol("indexed").
     |       setMaxCategories(2)
indexer: org.apache.spark.ml.feature.VectorIndexer = vecIdx_abee81bafba8

scala> val indexerModel = indexer.fit(df)
indexerModel: org.apache.spark.ml.feature.VectorIndexerModel = vecIdx_abee81bafba8

scala> val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet
categoricalFeatures: Set[Int] = Set(0, 2)

scala> println(s"Chose ${categoricalFeatures.size} categorical features: " + categoricalFeatures.mkString(", "))
Chose 2 categorical features: 0, 2

scala> val indexedData = indexerModel.transform(df)
indexedData: org.apache.spark.sql.DataFrame = [features: vector, indexed: vector]

scala> indexedData.foreach { println }
[[-1.0,1.0,1.0],[1.0,1.0,0.0]]
[[-1.0,3.0,1.0],[1.0,3.0,0.0]]
[[0.0,5.0,1.0],[0.0,5.0,0.0]]

​ 从上例可以看到,我们设置maxCategories为2,即只有种类小于2的特征才被认为是类别型特征,否则被认为是连续型特征。其中类别型特征将被进行编号索引,为了索引的稳定性,规定如果这个特征值为0,则一定会被编号成0,这样可以保证向量的稀疏度(未来还会再维持索引的稳定性上做更多的工作,比如如果某个特征类别化后只有一个特征,则会进行警告等等,这里就不过多介绍了)。于是,我们可以看到第0类和第2类的特征由于种类数不超过2,被划分成类别型特征,并进行了索引,且为0的特征值也被编号成了0号。


你可能感兴趣的:(技术层-scala)