自定义开发Spark ML机器学习类 - 1

初窥门径

Spark的MLlib组件内置实现了很多常见的机器学习算法,包括数据抽取,分类,聚类,关联分析,协同过滤等等.
然鹅,内置的算法并不能满足我们所有的需求,所以我们还是经常需要自定义ML算法.

MLlib提供的API分为两类:
- 1.基于DataFrame的API,属于spark.ml包.
- 2.基于RDD的API, 属于spark.mllib包.

从Spark 2.0开始,Spark的API全面从RDD转向DataFrame,MLlib也是如此,官网原话如下:

Announcement: DataFrame-based API is primary API

The MLlib RDD-based API is now in maintenance mode.

所以本文将介绍基于DataFrame的自定义ml类编写方法.不涉及具体算法,只讲扩展ml类的方法.

略知一二

官方文档并没有介绍如何自定义ml类,所以只有从源码入手,看看源码里面是怎么实现的.

找一个最简单的内置算法入手,这个算法就是内置的分词器,Tokenizer.

Tokenizer只是简单的将文本以空白部分进行分割,只适合给英文进行分词,所以它的实现及其简短,源码如下:

package org.apache.spark.ml.feature

import org.apache.spark.annotation.Since
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}

/**
 * A tokenizer that converts the input string to lowercase and then splits it by white spaces.
 *
 * @see [[RegexTokenizer]]
 */
@Since("1.2.0")
class Tokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
  extends UnaryTransformer[String, Seq[String], Tokenizer] with DefaultParamsWritable {

  @Since("1.2.0")
  def this() = this(Identifiable.randomUID("tok"))

  override protected def createTransformFunc: String => Seq[String] = {
    _.toLowerCase.split("\\s")
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == StringType, s"Input type must be string type but got $inputType.")
  }

  override protected def outputDataType: DataType = new ArrayType(StringType, true)

  @Since("1.4.1")
  override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra)
}

@Since("1.6.0")
object Tokenizer extends DefaultParamsReadable[Tokenizer] {

  @Since("1.6.0")
  override def load(path: String): Tokenizer = super.load(path)
}

简单分析下源码:
- Tokenizer继承了UnaryTransformer类.unary是’一元’的意思,也是说这个类实现的是类似一元函数的功能,一个输入变量,一个输出.直接看UnaryTransformer的源码注释:

/**
* :: DeveloperApi ::
* Abstract class for transformers that take one input column, apply transformation, and output the
* result as a new column.
*/

DeveloperApi表明这是一个开发级API,开发者可以用,不会有权限问题(源码中有很多private[spark]的类,是不允许外部调用的).
注释的大意就是:这是一个为实现transformers准备的抽象类,以一个字段(列)为输入,输出一个新字段(列).

所以实际上就是实现一个Transformer,只是这个Transformer有指定的输入字段和输出字段.

  • UnaryTransformer类中只有两个抽象方法.
    一个是createTransformFunc,是最核心的方法,这个方法需要返回一个函数,这个函数的参数即Transformer的输入字段的值,返回值为Transformer的输出字段的值.看看Tokenizer中的实现,就明白了.

另一个是outputDataType,这个方法用来返回输出字段的类型.

  • validateInputType方法是用来检查输入字段类型的,看需要实现.

  • Tokenizer混入了DefaultParamsWritable特质,使得自己可以被保存.
    对应的object Tokenizer伴生对象,用来读取已保存的Tokenizer.

  • 值得注意的是,Transformer类是PipelineStage类的子类,所以Transformer的子类,包括我们自定义的,是可以直接用在ML Pipelines中的.这就厉害了,说明自定义的算法类,可以无缝与内置机器学习算法打配合,还能利用Pipeline的调优工具(model selection,Cross-Validation等).

初出茅庐

看完源码,基本套路已经明了,不如动手抄一个,不,敲一个.
依葫芦画瓢,实现一个正则提取的Transformer.

import util.matching.Regex

import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types._

/**
  * 正则提取器
  * 将匹配指定正则表达式的全部子字符串,提取到array[string]中.
  */
class RegexExtractor(override val uid: String)
  extends UnaryTransformer[String, Seq[String], RegexExtractor] {

  def this() = this(Identifiable.randomUID("RegexExtractor"))

  /**
    * 参数:正则表达式
    *
    * @group param
    */
  final val regex = new Param[Regex](this, "RegexExpr", "正则表达式")

  /** @group setParam */
  def setRegexExpr(value: String): this.type = set(regex, new Regex(value))

  override protected def outputDataType: DataType = new ArrayType(StringType, true)

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == DataTypes.StringType,
      s"Input type must be string type but got $inputType."
    )
  }

  override protected def createTransformFunc: String => Seq[String] = {
    parseContent
  }

  /**
    * 数据处理
    */
  private def parseContent(text: String): Seq[String] = {
    if (text == null || text.isEmpty) {
      return Seq.empty[String]
    }
    $(regex).findAllIn(text).toSeq
  }

}

这个类结构与Tokenizer源码基本差不多,多用到的Param类,是一个参数的包装类.
作用是self-contained documentation and optionally default value.
其实就是把参数的值,文档,默认值等属性组合成一个类,方便调用.

比如上面定义的regex参数,就可以用$(regex)这样的方式直接调用.

另外在org.apache.spark.ml.param中有很多内置的Param类,可以直接使用.

同时org.apache.spark.ml.param.shared中有很多辅助引入参数的特质,比如HasInputCols特质,你的自定义Transformer只要混入这个特质就拥有了inputCols参数.不过目前shared中特质的作用域是private[ml],也就是说不能直接引用,而是要copy一份代码到自己的项目,并修改作用域才行.
关于这个作用域的问题,有人在spark的jira上提到,提议将其作为DeveloperApi开放出来,我也投了一票表示支持.后来在2017年11月终于resolved,该问题将在Spark2.3.0中解决.详情戳我

粗懂皮毛

自定义的类写好了,该怎么用呢? 当然是跟内置的一样啦.上栗子:

val regex="nidezhengze"

val tranTitle = new RegexExtractor()
    .setInputCol("title")
    .setOutputCol("title_price_texts")
    .setRegexExpr(regex)

val pipeline = new Pipeline().setStages(Array(
    tranTitle
))

val matched = pipeline.fit(data).transform(data)

打完收功

到这里,开发简单Transform的套路已经清楚了,不过这里实现的功能比较类似于一个UDF,只能对dataset的一个字段进行处理,而且是逐行处理,并不能根据多行数据进行处理,实现窗口函数类似的功能,而且也没有涉及模型的输出.如果要开发更复杂的算法,甚至进行模型训练,就需要更深入的了解MLlib了,阅读源码是个好途径.

下回再说.

你可能感兴趣的:(Spark)