This article uses Scala and targets Spark 2+.
Because these classes are not placed in the org.apache.spark.ml.feature package, many methods from the Spark source cannot be called directly; for example, below Spark 2.3 the traits in sharedParams cannot be extended directly.
Since Estimator is a subclass of PipelineStage, the stage can be used with the Spark 2.0 Pipeline API.
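As a quick illustration of that Pipeline integration, here is a minimal usage sketch; the local SparkSession, the toy DataFrame, and the column names "cat"/"label" are made up for this example and are not part of the implementation below.

import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("woe-demo").getOrCreate()
import spark.implicits._

//toy binary-label data: "cat" is a categorical feature, label 1 = bad, 0 = good
val df = Seq(("a", 1.0), ("a", 0.0), ("b", 1.0), ("b", 1.0), ("c", 0.0), ("c", 0.0)).toDF("cat", "label")

val woe = new WOE().setInputCol("cat").setOutputCol("cat_woe").setLabelCol("label")
val pipeline = new Pipeline().setStages(Array(woe))
val pipelineModel = pipeline.fit(df) //runs WOE.fit and wraps the resulting WOEModel
pipelineModel.transform(df).show() //appends the "cat_woe" column

The full implementation follows.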
import org.apache.spark.ml.util._
import org.apache.spark.ml.param._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.Estimator
import org.apache.spark.ml.Model
import org.apache.spark.sql.{ DataFrame, Dataset }
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkException
trait WOEBase extends Params {
//since Spark 2.3, the corresponding traits in sharedParams can be extended directly
final val labelCol: Param[String] = new Param[String](this, "labelCol", "label column name")
final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name")
final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name")
final val inputCols: StringArrayParam = new StringArrayParam(this, "inputCols", "input column names")
final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", "output column names")
def getLabelCol() = $(labelCol)
def getInputCol() = $(inputCol)
def getOutputCol() = $(outputCol)
def getInputCols() = $(inputCols)
def getOutputCols() = $(outputCols)
final val delta: DoubleParam = new DoubleParam(this, "delta", "smoothing constant added to avoid zero counts, which would cause division by zero or an infinite logarithm")
def getDelta() = $(delta)
setDefault(delta -> 1.0, labelCol -> "label") //Params method: sets default parameter values
//resolves the input/output column names into parallel arrays, whether the single-column or the multi-column params were set
protected def getInOutCols: (Array[String], Array[String]) = {
//require is a predefined method on scala.Predef: it throws IllegalArgumentException when its condition is false
require(
//isSet checks whether a param was explicitly set via a set method; defaults assigned through setDefault return false, not true,
//which lets us give any column a default value without having to change this check
(isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)) ||
(!isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)),
"WOE only supports setting either inputCol/outputCol or " +
"inputCols/outputCols.")
if (isSet(inputCol)) { //isSet: Params method that checks whether a param value has been explicitly set
(Array($(inputCol)), Array($(outputCol)))
} else {
require(
$(inputCols).length == $(outputCols).length,
"inputCols number do not match outputCols")
($(inputCols), $(outputCols))
}
}
protected def validateAndTransformSchemas(schema: StructType): StructType = {
//a StructField holds a field name, a type (e.g. StringType, IntegerType, ArrayType), nullability, and metadata
//a StructType is a collection of StructFields; a schema is a StructType
//the data types supported by Dataset live in org.apache.spark.sql.types and are mostly subclasses of DataType
val labelColName = $(labelCol)
val labelDataType = schema(labelColName).dataType
require(
labelDataType.isInstanceOf[NumericType],
s"The label column $labelColName must be numeric type, " +
s"but got $labelDataType.")
val (inputColNames, outputColNames) = getInOutCols
val existingFields = schema.fields
var outputFields = existingFields
inputColNames.zip(outputColNames).foreach {
case (inputColName, outputColName) =>
require(
existingFields.exists(_.name == inputColName),
s"Iutput column ${inputColName} not exists.")
require(
existingFields.forall(_.name != outputColName),
s"Output column ${outputColName} already exists.")
val attr = NominalAttribute.defaultAttr.withName(outputColName)
outputFields :+= attr.toStructField()
}
StructType(outputFields)
}
}
class WOE(override val uid: String)
extends Estimator[WOEModel]
with WOEBase with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("WOE")) //WOE_ 和一个随机数组成的标识符作为uid
//set方法
def setLabelCol(value: String): this.type = set(labelCol, value)
def setInputCol(value: String): this.type = set(inputCol, value)
def setOutputCol(value: String): this.type = set(outputCol, value)
def setInputCols(value: Array[String]): this.type = set(inputCols, value)
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
def setDelta(value: Double): this.type = set(delta, value)
override def copy(extra: ParamMap): WOE = defaultCopy(extra) //required override; delegating to the default defaultCopy is enough (returning WOE rather than this.type, since defaultCopy produces a new instance)
override def fit(dataset: Dataset[_]): WOEModel = { //required override; this is where the learning logic lives
transformSchema(dataset.schema, true) //PipelineStage method; it calls this class's transformSchema, and the boolean decides whether the schemas before and after transformation are logged via logDebug
val delta_value = $(delta) //smoothing constant added to avoid zero counts
val T = dataset.count
// val B = dataset.agg(sum("y")).first.getLong(0)
val B = dataset.where($(labelCol) + " = 1").count()
val G = T - B
val woe_map_arr = new ArrayBuffer[Map[String, Double]]()
val (inputColNames, outputColNames) = getInOutCols
inputColNames.foreach {
inputColName =>
val gDs_t = dataset.groupBy(inputColName).agg(count($(labelCol)).as("T"), sum($(labelCol)).as("B"))
val gDs = gDs_t.withColumn("G", gDs_t("T") - gDs_t("B"))
val loger = udf { d: Double =>
math.log(d)
}
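//for each distinct value i of the input column, the expression below computes
//  WOE_i = ln( ((B_i + delta) / (B + delta)) / ((G_i + delta) / (G + delta)) )
//where B_i and G_i are the bad/good counts within group i, B and G are the overall bad/good counts,
//and delta smooths away zero counts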
val woe_map = gDs.withColumn("woe", loger((gDs("B") + delta_value) / (B + delta_value) * (G + delta_value) / (gDs("G") + delta_value)))
.select(col(inputColName).cast(StringType), col("woe"))
.collect()
.map(r => (r.getString(0), r.getDouble(1)))
.toMap
woe_map_arr += woe_map
}
copyValues(new WOEModel(uid, woe_map_arr.toSeq).setParent(this)) //copyValues: a Params method that copies the parent's param values onto the model (for params the model also defines)
}
override def transformSchema(schema: StructType): StructType = { //required override that returns the transformed schema; making it a no-op and skipping the call in fit would probably still work, but this has not been tested
validateAndTransformSchemas(schema)
}
}
//save delegates to write.save
//load delegates to read.load
object WOE extends DefaultParamsReadable[WOE] {
override def load(path: String): WOE = super.load(path) //follows the Spark convention; it simply uses DefaultParamsReadable's implementation, which delegates to DefaultParamsReader.load
}
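//a quick sketch of round-tripping the unfitted stage (the path "/tmp/woe_stage" is only an example):
//  new WOE().setInputCol("cat").setOutputCol("cat_woe").write.overwrite().save("/tmp/woe_stage")
//  val reloadedStage = WOE.load("/tmp/woe_stage")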
//What an Estimator learns and hands to its Transformer is essentially just a data structure.
//fit passes this learned result to the Transformer: either as a constructor argument or through param setters;
//here it is passed through the constructor. Passing it as a param is generally preferable, because the model's rules can then be read and set like any other param.
//To keep the logic simple, single-column and multi-column transforms (inputCol vs inputCols) are not distinguished: both are treated as multi-column. Likewise, the write method of class WOE also has to be overridden.
class WOEModel(override val uid: String, val woe_map_arr: Seq[Map[String, Double]])
extends Model[WOEModel]
with MLWritable with WOEBase {
def this(woe_map_arr: Seq[Map[String, Double]]) = this(Identifiable.randomUID("WOE"), woe_map_arr)
def setInputCol(value: String): this.type = set(inputCol, value)
def setLabelCol(value: String): this.type = set(labelCol, value)
def setOutputCol(value: String): this.type = set(outputCol, value)
def setInputCols(value: Array[String]): this.type = set(inputCols, value)
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
// def setDelta(value: Double): this.type = set(delta, value) //the model does not expose setDelta: delta only matters for fitting, not for transforming
override def copy(extra: ParamMap): WOEModel = {
val copied = new WOEModel(uid, woe_map_arr)
copyValues(copied, extra).setParent(parent) //copyValues copies the param values onto the new instance
}
import WOEModel._
override def write: WOEModelWriter = new WOEModelWriter(this)
override def transform(dataset: Dataset[_]): DataFrame = {
val (inputColNames, outputColNames) = getInOutCols
transformSchema(dataset.schema)
require(
woe_map_arr.length == inputColNames.length,
s"The number of input columns is not equal to the number of WOEModel model maps ")
var df: DataFrame = dataset.toDF()
woe_map_arr.zipWithIndex.foreach {
case (woe_map, idx) =>
val inputColName = inputColNames(idx)
val outputColName = outputColNames(idx)
val woer = udf { (feature: String) =>
woe_map.get(feature) match {
case Some(n: Double) => n
case None =>
//throw an exception here; an earlier attempt to return the dataset instead caused errors
throw new SparkException(s"Input column ${inputColName}'s value ${feature} does not exist in the WOEModel map. " +
"Skip WOEModel.")
}
}//.asNondeterministic() //supported from Spark 2.3
df = df.withColumn(outputColName, woer(dataset(inputColName).cast(StringType)))
}
df
}
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchemas(schema)
}
}
object WOEModel extends MLReadable[WOEModel] {
import org.apache.hadoop.fs.Path
import org.json4s.JsonDSL._
//the render and compact methods come from here
import org.json4s.jackson.JsonMethods._
import org.json4s.JsonAST._
implicit val format = org.json4s.DefaultFormats
//the persistence details are implemented by hand here; the Spark source has a shared implementation, but it is private[ml]
private[WOEModel] class WOEModelWriter(instance: WOEModel) extends MLWriter {
private case class Data(woe_map_arr: Seq[Map[String, Double]])
override protected def saveImpl(path: String): Unit = {
//all params are saved here; since our outputCol has no default value, this is safe
//if outputCol had a default value and the inputCols/outputCols params were set, outputCol's default would have to be excluded when saving:
//once the default is saved, loading applies it with set rather than setDefault, and the exclusive-params check in transform then fails
//see SPARK-23377 for details
val metadataPath = new Path(path, "metadata").toString
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
val jsonParams = render(
params.map {
case ParamPair(p, v) =>
p.name -> parse(p.jsonEncode(v))
}.toList)
val basicMetadata = ("class" -> instance.getClass.getName) ~
("timestamp" -> System.currentTimeMillis()) ~
("sparkVersion" -> sc.version) ~
("uid" -> instance.uid) ~
("paramMap" -> jsonParams)
val metadataJson = compact(render(basicMetadata))
sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
val data = Data(instance.woe_map_arr)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
private class WOEModelReader extends MLReader[WOEModel] {
private val className = classOf[WOEModel].getName
override def load(path: String): WOEModel = {
val metadataPath = new Path(path, "metadata").toString
val s = sc.textFile(metadataPath, 1).first()
val metadata = parse(s)
val clz = (metadata \ "class").extract[String]
val uid = (metadata \ "uid").extract[String]
require(className == clz, s"Error loading metadata: Expected class name" +
s" $className but found class name $clz")
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
.select("woe_map_arr")
.head()
val woe_map_arr = data.getAs[Seq[Map[String, Double]]](0)
val instance = new WOEModel(uid, woe_map_arr)
val params = metadata \ "paramMap"
params match {
case JObject(pairs) =>
pairs.foreach {
case (paramName, jsonValue) =>
val param = instance.getParam(paramName)
val value = param.jsonDecode(compact(render(jsonValue)))
instance.set(param, value)
}
case _ =>
throw new IllegalArgumentException(
s"Cannot recognize JSON metadata: ${s}.")
}
instance
}
}
override def read: MLReader[WOEModel] = new WOEModelReader
override def load(path: String): WOEModel = super.load(path)
}
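To close, a sketch of persisting a fitted model with the custom writer and reader above. It reuses the SparkSession and implicits from the first sketch; the two-column toy data and the path "/tmp/woe_model" are made up for illustration.

//toy data with two categorical columns
val df2 = Seq(("a", "x", 1.0), ("b", "y", 0.0), ("a", "y", 1.0), ("b", "x", 0.0)).toDF("cat1", "cat2", "label")

val fitted: WOEModel = new WOE()
  .setInputCols(Array("cat1", "cat2"))
  .setOutputCols(Array("cat1_woe", "cat2_woe"))
  .setLabelCol("label")
  .fit(df2)

fitted.write.overwrite().save("/tmp/woe_model") //WOEModelWriter: JSON metadata plus a parquet file holding the maps
val reloaded = WOEModel.load("/tmp/woe_model") //WOEModelReader rebuilds the model and re-applies its params
reloaded.transform(df2).show() //appends cat1_woe and cat2_woe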