Spark的MLlib组件内置实现了很多常见的机器学习算法,包括数据抽取,分类,聚类,关联分析,协同过滤等等.
然鹅,内置的算法并不能满足我们所有的需求,所以我们还是经常需要自定义ML算法.
MLlib提供的API分为两类:
- 1.基于DataFrame的API,属于spark.ml包.
- 2.基于RDD的API, 属于spark.mllib包.
从Spark 2.0开始,Spark的API全面从RDD转向DataFrame,MLlib也是如此,官网原话如下:
Announcement: DataFrame-based API is primary API
The MLlib RDD-based API is now in maintenance mode.
所以本文将介绍基于DataFrame的自定义ml类编写方法.不涉及具体算法,只讲扩展ml类的方法.
官方文档并没有介绍如何自定义ml类,所以只有从源码入手,看看源码里面是怎么实现的.
找一个最简单的内置算法入手,这个算法就是内置的分词器,Tokenizer.
Tokenizer只是简单的将文本以空白部分进行分割,只适合给英文进行分词,所以它的实现及其简短,源码如下:
package org.apache.spark.ml.feature
import org.apache.spark.annotation.Since
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
/**
* A tokenizer that converts the input string to lowercase and then splits it by white spaces.
*
* @see [[RegexTokenizer]]
*/
@Since("1.2.0")
class Tokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends UnaryTransformer[String, Seq[String], Tokenizer] with DefaultParamsWritable {
@Since("1.2.0")
def this() = this(Identifiable.randomUID("tok"))
override protected def createTransformFunc: String => Seq[String] = {
_.toLowerCase.split("\\s")
}
override protected def validateInputType(inputType: DataType): Unit = {
require(inputType == StringType, s"Input type must be string type but got $inputType.")
}
override protected def outputDataType: DataType = new ArrayType(StringType, true)
@Since("1.4.1")
override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra)
}
@Since("1.6.0")
object Tokenizer extends DefaultParamsReadable[Tokenizer] {
@Since("1.6.0")
override def load(path: String): Tokenizer = super.load(path)
}
简单分析下源码:
- Tokenizer继承了UnaryTransformer类.unary是’一元’的意思,也是说这个类实现的是类似一元函数的功能,一个输入变量,一个输出.直接看UnaryTransformer的源码注释:
/**
* :: DeveloperApi ::
* Abstract class for transformers that take one input column, apply transformation, and output the
* result as a new column.
*/
DeveloperApi表明这是一个开发级API,开发者可以用,不会有权限问题(源码中有很多private[spark]的类,是不允许外部调用的).
注释的大意就是:这是一个为实现transformers准备的抽象类,以一个字段(列)为输入,输出一个新字段(列).
所以实际上就是实现一个Transformer,只是这个Transformer有指定的输入字段和输出字段.
另一个是outputDataType,这个方法用来返回输出字段的类型.
validateInputType方法是用来检查输入字段类型的,看需要实现.
Tokenizer混入了DefaultParamsWritable特质,使得自己可以被保存.
对应的object Tokenizer伴生对象,用来读取已保存的Tokenizer.
值得注意的是,Transformer类是PipelineStage类的子类,所以Transformer的子类,包括我们自定义的,是可以直接用在ML Pipelines中的.这就厉害了,说明自定义的算法类,可以无缝与内置机器学习算法打配合,还能利用Pipeline的调优工具(model selection,Cross-Validation等).
看完源码,基本套路已经明了,不如动手抄一个,不,敲一个.
依葫芦画瓢,实现一个正则提取的Transformer.
import util.matching.Regex
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types._
/**
* 正则提取器
* 将匹配指定正则表达式的全部子字符串,提取到array[string]中.
*/
class RegexExtractor(override val uid: String)
extends UnaryTransformer[String, Seq[String], RegexExtractor] {
def this() = this(Identifiable.randomUID("RegexExtractor"))
/**
* 参数:正则表达式
*
* @group param
*/
final val regex = new Param[Regex](this, "RegexExpr", "正则表达式")
/** @group setParam */
def setRegexExpr(value: String): this.type = set(regex, new Regex(value))
override protected def outputDataType: DataType = new ArrayType(StringType, true)
override protected def validateInputType(inputType: DataType): Unit = {
require(inputType == DataTypes.StringType,
s"Input type must be string type but got $inputType."
)
}
override protected def createTransformFunc: String => Seq[String] = {
parseContent
}
/**
* 数据处理
*/
private def parseContent(text: String): Seq[String] = {
if (text == null || text.isEmpty) {
return Seq.empty[String]
}
$(regex).findAllIn(text).toSeq
}
}
这个类结构与Tokenizer源码基本差不多,多用到的Param类,是一个参数的包装类.
作用是self-contained documentation and optionally default value.
其实就是把参数的值,文档,默认值等属性组合成一个类,方便调用.
比如上面定义的regex参数,就可以用$(regex)这样的方式直接调用.
另外在org.apache.spark.ml.param中有很多内置的Param类,可以直接使用.
同时org.apache.spark.ml.param.shared中有很多辅助引入参数的特质,比如HasInputCols特质,你的自定义Transformer只要混入这个特质就拥有了inputCols参数.不过目前shared中特质的作用域是private[ml],也就是说不能直接引用,而是要copy一份代码到自己的项目,并修改作用域才行.
关于这个作用域的问题,有人在spark的jira上提到,提议将其作为DeveloperApi开放出来,我也投了一票表示支持.后来在2017年11月终于resolved,该问题将在Spark2.3.0中解决.详情戳我
自定义的类写好了,该怎么用呢? 当然是跟内置的一样啦.上栗子:
val regex="nidezhengze"
val tranTitle = new RegexExtractor()
.setInputCol("title")
.setOutputCol("title_price_texts")
.setRegexExpr(regex)
val pipeline = new Pipeline().setStages(Array(
tranTitle
))
val matched = pipeline.fit(data).transform(data)
到这里,开发简单Transform的套路已经清楚了,不过这里实现的功能比较类似于一个UDF,只能对dataset的一个字段进行处理,而且是逐行处理,并不能根据多行数据进行处理,实现窗口函数类似的功能,而且也没有涉及模型的输出.如果要开发更复杂的算法,甚至进行模型训练,就需要更深入的了解MLlib了,阅读源码是个好途径.
下回再说.