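Spark ML standardizes the APIs of machine learning algorithms, which makes it easier to combine multiple algorithms into a single pipeline or workflow. The main concepts are as follows.
DataFrame: the ML API uses the DataFrame from Spark SQL as the ML dataset, which can hold a variety of data types. For example, a DataFrame can have separate columns storing text, feature vectors, true labels, and predictions.
Transformer: a Transformer is an algorithm that turns one DataFrame into another. For example, an ML model is a Transformer that turns a DataFrame with features into a DataFrame with predictions.
Estimator: an Estimator is an algorithm that is fit on a DataFrame to produce a Transformer. For example, a learning algorithm is an Estimator that trains on a DataFrame and produces a model.
Pipeline: a Pipeline chains multiple Transformers and Estimators together to specify an ML workflow.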
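The four concepts above fit together roughly as in the following sketch. It assumes Spark 2.3.x with spark-mllib on the classpath; the toy DataFrame, the column names `category`, `amount`, `label`, and `features`, and the object name `PipelineSketch` are made up for illustration and are not part of the Avazu example.

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.sql.{DataFrame, SparkSession}

object PipelineSketch {
  def run(spark: SparkSession): DataFrame = {
    import spark.implicits._
    // Toy data (hypothetical): one string column, one numeric column, one label.
    val df = Seq(("a", 1.0, 0.0), ("b", 2.0, 1.0), ("a", 3.0, 0.0), ("b", 4.0, 1.0))
      .toDF("category", "amount", "label")

    // Estimator: learns a string -> index mapping when fit.
    val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("category_index")
    // Transformer: concatenates columns into one feature vector.
    val assembler = new VectorAssembler()
      .setInputCols(Array("category_index", "amount"))
      .setOutputCol("features")
    // Estimator: learns a LogisticRegressionModel when fit.
    val lr = new LogisticRegression().setMaxIter(5).setRegParam(0.1)

    // The Pipeline is itself an Estimator; fitting it yields a PipelineModel (a Transformer).
    val pipeline = new Pipeline().setStages(Array[PipelineStage](indexer, assembler, lr))
    val model: PipelineModel = pipeline.fit(df)
    model.transform(df)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("PipelineSketch").master("local[2]").getOrCreate()
    run(spark).select("category", "label", "prediction").show(false)
    spark.stop()
  }
}
```

Fitting the pipeline runs each stage in order: Estimators are fit and replaced by the Transformers they produce, so the resulting `PipelineModel` can be applied to new data with a single `transform` call.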
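This post uses Spark ML to predict ad clicks. The training and test data are the sample data from the Kaggle Avazu CTR competition, available at https://www.kaggle.com/c/avazu-ctr-prediction/data.
Development environment: Java 1.8.0_172 + Scala 2.11.8 + Spark 2.3.1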
Dependencies (groupId:artifactId:version):
org.apache.spark:spark-core_2.11:2.3.1
org.apache.spark:spark-sql_2.11:2.3.1
org.apache.spark:spark-hive_2.11:2.3.1
org.apache.spark:spark-mllib_2.11:2.3.1
Spark loads the CSV file; the basic structure of the DataFrame is as follows:
val data = spark.read.csv("/opt/data/ads_6M.csv").toDF(
  "id","click","hour","C1","banner_pos","site_id","site_domain",
  "site_category","app_id","app_domain","app_category","device_id","device_ip",
  "device_model","device_type","device_conn_type","C14","C15","C16","C17","C18",
  "C19","C20","C21")
data.show(5,false)
+--------------------+-----+--------+----+----------+--------+-----------+-------------+--------+----------+------------+---------+---------+------------+-----------+----------------+-----+---+---+----+---+---+------+---+
|id |click|hour |C1 |banner_pos|site_id |site_domain|site_category|app_id |app_domain|app_category|device_id|device_ip|device_model|device_type|device_conn_type|C14 |C15|C16|C17 |C18|C19|C20 |C21|
+--------------------+-----+--------+----+----------+--------+-----------+-------------+--------+----------+------------+---------+---------+------------+-----------+----------------+-----+---+---+----+---+---+------+---+
|10153523536315735769|0 |14102100|1005|0 |85f751fd|c4e18dd6 |50e219e0 |53de0284|d9b5648e |0f2161f8 |a99f214a |788c3e75 |2ea4f8ba |1 |0 |20508|320|50 |2351|3 |163|-1 |61 |
|10448041871517116234|0 |14102100|1005|0 |1fbe01fe|f3845767 |28905ebd |ecad2386|7801e8d9 |07d7df22 |a99f214a |99cd8fa2 |81b42528 |1 |0 |15707|320|50 |1722|0 |35 |-1 |79 |
|10488488220071431784|0 |14102100|1005|1 |72a56356|45368af7 |3e814130 |ecad2386|7801e8d9 |07d7df22 |a99f214a |e8fc2f9f |900981af |1 |2 |18993|320|50 |2161|0 |35 |-1 |157|
|10625948582770087788|0 |14102100|1005|0 |85f751fd|c4e18dd6 |50e219e0 |5e3f096f|2347f47a |0f2161f8 |a99f214a |9c1b8be7 |24f6b932 |1 |0 |18993|320|50 |2161|0 |35 |100215|157|
|11151072182888929242|0 |14102100|1005|1 |5b4d2eda|16a36ef3 |f028772b |ecad2386|7801e8d9 |07d7df22 |a99f214a |866e0a54 |d787e91b |1 |0 |16208|320|50 |1800|3 |167|-1 |23 |
+--------------------+-----+--------+----+----------+--------+-----------+-------------+--------+----------+------------+---------+---------+------------+-----------+----------------+-----+---+---+----+---+---+------+---+
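The dataset has 24 fields. Columns 6 to 14 (site_id through device_model) are categorical features, and columns 15 to 24 (device_type through C21) are used as numeric features. Because read.csv is called without a schema, every column is loaded as a string.
The dataset is split into a training set and a test set at a ratio of 0.7:0.3.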
val splited = data.randomSplit(Array(0.7,0.3),2L)
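For categorical features, StringIndexer encodes a string column into a column of indices, turning string features into numeric ones that downstream pipeline components can handle. The indexer is fit on the training set only, and the fitted model is applied to both sets so that the same string maps to the same index in each.
val catalog_features = Array("click","site_id","site_domain","site_category","app_id","app_domain","app_category","device_id","device_ip","device_model")
var train_index = splited(0)
var test_index = splited(1)
for (catalog_feature <- catalog_features) {
  // handleInvalid = "keep": values that appear only in the test set get an
  // extra index instead of raising an error at transform time.
  val indexer = new StringIndexer()
    .setInputCol(catalog_feature)
    .setOutputCol(catalog_feature.concat("_index"))
    .setHandleInvalid("keep")
  // Fit on the training set only, then reuse the fitted model for both sets.
  val train_index_model = indexer.fit(train_index)
  train_index = train_index_model.transform(train_index)
  test_index = train_index_model.transform(test_index)
}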
println("String-indexed labels:")
train_index.show(5,false)
test_index.show(5,false)
String-indexed labels:
+--------------------+-----+--------+----+----------+--------+-----------+-------------+--------+----------+------------+---------+---------+------------+-----------+----------------+-----+---+---+----+---+----+------+---+-----------+-------------+-----------------+-------------------+------------+----------------+------------------+---------------+---------------+------------------+
|id |click|hour |C1 |banner_pos|site_id |site_domain|site_category|app_id |app_domain|app_category|device_id|device_ip|device_model|device_type|device_conn_type|C14 |C15|C16|C17 |C18|C19 |C20 |C21|click_index|site_id_index|site_domain_index|site_category_index|app_id_index|app_domain_index|app_category_index|device_id_index|device_ip_index|device_model_index|
+--------------------+-----+--------+----+----------+--------+-----------+-------------+--------+----------+------------+---------+---------+------------+-----------+----------------+-----+---+---+----+---+----+------+---+-----------+-------------+-----------------+-------------------+------------+----------------+------------------+---------------+---------------+------------------+
|10000133892746881176|0 |14102813|1005|0 |85f751fd|c4e18dd6 |50e219e0 |febd1138|82e27996 |0f2161f8 |a99f214a |f5c62586 |b4b19c97 |1 |0 |21611|320|50 |2480|3 |297 |100111|61 |0.0 |0.0 |0.0 |0.0 |4.0 |4.0 |1.0 |0.0 |23751.0 |64.0 |
|10000987464039884177|0 |14102816|1005|0 |5bcf81a2|9d54950b |f028772b |ecad2386|7801e8d9 |07d7df22 |a99f214a |845f69f4 |fa61e8fe |1 |0 |23438|320|50 |2684|2 |1319|-1 |52 |0.0 |11.0 |7.0 |1.0 |0.0 |0.0 |0.0 |0.0 |5237.0 |67.0 |
|10001055656394300907|0 |14102814|1005|0 |85f751fd|c4e18dd6 |50e219e0 |e9739828|df32afa9 |cef3e649 |a99f214a |6454c6ba |ecb851b2 |1 |0 |23441|320|50 |2685|1 |33 |100083|212|0.0 |0.0 |0.0 |0.0 |13.0 |11.0 |2.0 |0.0 |18147.0 |8.0 |
|10001237608243220141|0 |14102701|1005|0 |85f751fd|c4e18dd6 |50e219e0 |febd1138|82e27996 |0f2161f8 |a99f214a |ab986e15 |2ea4f8ba |1 |0 |19743|320|50 |2264|3 |427 |100000|61 |0.0 |0.0 |0.0 |0.0 |4.0 |4.0 |1.0 |0.0 |23941.0 |34.0 |
|10001363001408225332|0 |14102812|1005|1 |85f751fd|c4e18dd6 |50e219e0 |1dc72b4d|2347f47a |0f2161f8 |b7c2e4b6 |bce45090 |5db079b5 |1 |2 |22998|300|50 |2657|3 |35 |100013|23 |0.0 |0.0 |0.0 |0.0 |25.0 |1.0 |1.0 |1760.0 |729.0 |25.0 |
+--------------------+-----+--------+----+----------+--------+-----------+-------------+--------+----------+------------+---------+---------+------------+-----------+----------------+-----+---+---+----+---+----+------+---+-----------+-------------+-----------------+-------------------+------------+----------------+------------------+---------------+---------------+------------------+
only showing top 5 rows
+--------------------+-----+--------+----+----------+--------+-----------+-------------+--------+----------+------------+---------+---------+------------+-----------+----------------+-----+---+---+----+---+---+------+---+-----------+-------------+-----------------+-------------------+------------+----------------+------------------+---------------+---------------+------------------+
|id |click|hour |C1 |banner_pos|site_id |site_domain|site_category|app_id |app_domain|app_category|device_id|device_ip|device_model|device_type|device_conn_type|C14 |C15|C16|C17 |C18|C19|C20 |C21|click_index|site_id_index|site_domain_index|site_category_index|app_id_index|app_domain_index|app_category_index|device_id_index|device_ip_index|device_model_index|
+--------------------+-----+--------+----+----------+--------+-----------+-------------+--------+----------+------------+---------+---------+------------+-----------+----------------+-----+---+---+----+---+---+------+---+-----------+-------------+-----------------+-------------------+------------+----------------+------------------+---------------+---------------+------------------+
|10002333262420133303|0 |14102211|1005|1 |856e6d3f|58a89a43 |f028772b |ecad2386|7801e8d9 |07d7df22 |a99f214a |ac322dfb |0dc22ebc |1 |0 |19771|320|50 |2227|0 |679|100077|48 |0.0 |6.0 |6.0 |1.0 |0.0 |0.0 |0.0 |0.0 |6004.0 |279.0 |
|10002749335348787004|1 |14102800|1005|0 |2a68aa20|9b851bd8 |3e814130 |ecad2386|7801e8d9 |07d7df22 |a99f214a |b4a0ec64 |49bc419a |1 |0 |20213|320|50 |2316|0 |167|100079|16 |1.0 |57.0 |56.0 |3.0 |0.0 |0.0 |0.0 |0.0 |30.0 |563.0 |
|10003763177308262205|0 |14102814|1002|0 |7971d583|c4e18dd6 |50e219e0 |ecad2386|7801e8d9 |07d7df22 |fffcf8a4 |f615f762 |a5df7413 |0 |0 |23441|320|50 |2685|1 |33 |-1 |212|0.0 |408.0 |0.0 |0.0 |0.0 |0.0 |0.0 |1003.0 |5471.0 |982.0 |
|10005435104591133943|0 |14102719|1005|0 |85f751fd|c4e18dd6 |50e219e0 |92f5800b|ae637522 |0f2161f8 |a99f214a |8f2784a2 |0bcabeaf |1 |3 |21189|320|50 |2424|1 |161|100193|71 |0.0 |0.0 |0.0 |0.0 |1.0 |2.0 |1.0 |0.0 |4207.0 |19.0 |
|10006076676750034840|0 |14102522|1005|1 |e151e245|7e091613 |f028772b |ecad2386|7801e8d9 |07d7df22 |a99f214a |dc88197f |fce66524 |1 |0 |4687 |320|50 |423 |2 |39 |100148|32 |0.0 |2.0 |2.0 |1.0 |0.0 |0.0 |0.0 |0.0 |4109.0 |22.0 |
+--------------------+-----+--------+----+----------+--------+-----------+-------------+--------+----------+------------+---------+---------+------------+-----------+----------------+-----+---+---+----+---+---+------+---+-----------+-------------+-----------------+-------------------+------------+----------------+------------------+---------------+---------------+------------------+
only showing top 5 rows
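To make the indexing behaviour concrete, here is a plain-Scala sketch (no Spark needed) of what StringIndexer does by default: labels are ordered by descending frequency, so the most frequent value gets index 0.0. The names `StringIndexSketch`, `fitIndex`, and `applyIndex` are invented for this illustration, ties are broken alphabetically as a simplifying assumption, and unseen values are mapped to one extra index, similar in spirit to `setHandleInvalid("keep")`.

```scala
object StringIndexSketch {
  /** Learn a value -> index mapping from the training values, most frequent first. */
  def fitIndex(values: Seq[String]): Map[String, Double] = {
    val counts = values.groupBy(identity).map { case (v, vs) => (v, vs.size) }.toSeq
    // Sort by descending count; break ties alphabetically (an assumption of this sketch).
    val ordered = counts.sortBy { case (v, n) => (-n, v) }
    ordered.zipWithIndex.map { case ((v, _), i) => v -> i.toDouble }.toMap
  }

  /** Apply a learned mapping; every unseen value shares the extra index `mapping.size`. */
  def applyIndex(mapping: Map[String, Double], values: Seq[String]): Seq[Double] =
    values.map(v => mapping.getOrElse(v, mapping.size.toDouble))

  def main(args: Array[String]): Unit = {
    val train = Seq("85f751fd", "1fbe01fe", "85f751fd", "5b4d2eda", "85f751fd", "1fbe01fe")
    // "ffffffff" is a hypothetical value that never occurs in the training data.
    val test = Seq("1fbe01fe", "85f751fd", "ffffffff")
    val mapping = fitIndex(train)
    println(mapping)
    println(applyIndex(mapping, test))
  }
}
```

The important point is that the mapping is learned once, from the training values, and then reused unchanged for the test values; learning a second mapping from the test set would assign different indices to the same strings.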
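Feature hashing projects a set of categorical or numeric features into a feature vector of a specified dimension (usually much smaller than the original feature space). It uses the hashing trick to map features to indices in the feature vector.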
val hasher = new FeatureHasher()
  .setInputCols("site_id_index","site_domain_index","site_category_index","app_id_index","app_domain_index","app_category_index","device_id_index","device_ip_index","device_model_index","device_type","device_conn_type","C14","C15","C16","C17","C18","C19","C20","C21")
  .setOutputCol("feature")
val train_hs = hasher.transform(train_index)
val test_hs = hasher.transform(test_index)
println("Feature hashing output:")
train_hs.show(5,false)
test_hs.show(5,false)
The output has the same columns as the StringIndexer tables above, plus a feature column holding the hashed sparse vector.
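The hashing trick itself is small enough to sketch in plain Scala. The version below is an illustration only: it uses `scala.util.hashing.MurmurHash3.stringHash` from the standard library rather than Spark's internal hash, so the indices will not match what FeatureHasher produces, and `HashingTrickSketch`, `hashIndex`, and `hashRow` are names invented here. It follows the convention described in the Spark documentation: a string (categorical) value contributes 1.0 at the index of hash("column=value"), while a numeric value contributes its own value at the index of hash("column").

```scala
import scala.util.hashing.MurmurHash3

object HashingTrickSketch {
  /** Map a key to a bucket in [0, numFeatures) using a non-negative modulo. */
  def hashIndex(key: String, numFeatures: Int): Int = {
    val raw = MurmurHash3.stringHash(key) % numFeatures
    if (raw < 0) raw + numFeatures else raw
  }

  /** Hash one row into a sparse vector (index -> value); colliding entries are summed. */
  def hashRow(row: Map[String, Any], numFeatures: Int): Map[Int, Double] = {
    val entries = row.toSeq.map {
      // Categorical: indicator 1.0 at the bucket of "column=value".
      case (col, v: String) => (hashIndex(s"$col=$v", numFeatures), 1.0)
      // Numeric: the value itself at the bucket of the column name.
      case (col, v: Double) => (hashIndex(col, numFeatures), v)
      // Anything else is treated as categorical via its string form (sketch-only choice).
      case (col, other) => (hashIndex(s"$col=$other", numFeatures), 1.0)
    }
    entries.groupBy(_._1).map { case (i, vs) => i -> vs.map(_._2).sum }
  }

  def main(args: Array[String]): Unit = {
    val row = Map[String, Any]("site_id" -> "85f751fd", "device_type" -> "1", "C15" -> 320.0)
    println(hashRow(row, 1024))
  }
}
```

A smaller `numFeatures` saves memory but makes collisions between unrelated features more likely; FeatureHasher defaults to 2^18 = 262144 buckets and lets you change this with `setNumFeatures`.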
The LR (logistic regression) model from Spark ML is used to predict ad clicks. The parameters are set as follows:
val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.0)
  .setFeaturesCol("feature")
  .setLabelCol("click_index")
  .setPredictionCol("click_predict")
val model_lr = lr.fit(train_hs)
println(s"Coefficients: ${model_lr.coefficients} Intercept: ${model_lr.intercept}")
val predictions = model_lr.transform(test_hs)
predictions.select("click_index","click_predict","probability").show(10,false)
// MulticlassMetrics expects an RDD of (prediction, label) pairs.
val predictionRdd = predictions.select("click_predict","click_index").rdd.map{
  case Row(click_predict:Double,click_index:Double)=>(click_predict,click_index)
}
val metrics = new MulticlassMetrics(predictionRdd)
val accuracy = metrics.accuracy
val weightedPrecision = metrics.weightedPrecision
val weightedRecall = metrics.weightedRecall
val f1 = metrics.weightedFMeasure
println(s"LR evaluation results:\nAccuracy: ${accuracy}\nWeighted precision: ${weightedPrecision}\nWeighted recall: ${weightedRecall}\nWeighted F1: ${f1}")
+-----------+-------------+----------------------------------------+
|click_index|click_predict|probability |
+-----------+-------------+----------------------------------------+
|0.0 |0.0 |[0.8673583515173942,0.13264164848260582]|
|1.0 |0.0 |[0.7065355297971061,0.29346447020289396]|
|0.0 |0.0 |[0.9247213791421071,0.07527862085789287]|
|0.0 |0.0 |[0.9411799267286762,0.05882007327132381]|
|0.0 |0.0 |[0.7534455683444734,0.24655443165552665]|
|0.0 |0.0 |[0.8993737856386326,0.10062621436136741]|
|0.0 |0.0 |[0.8837461636081269,0.11625383639187312]|
|0.0 |0.0 |[0.8320314092251319,0.16796859077486806]|
|0.0 |0.0 |[0.9027137639161569,0.09728623608384318]|
|1.0 |0.0 |[0.8791816482313737,0.12081835176862625]|
+-----------+-------------+----------------------------------------+
only showing top 10 rows
LR evaluation results:
Accuracy: 0.8308678500986193
Weighted precision: 0.7886992955593048
Weighted recall: 0.8308678500986193
Weighted F1: 0.7596712330402737
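The weighted metrics printed above average the per-class values using each class's share of the true labels. The plain-Scala sketch below recomputes them from (prediction, label) pairs on a tiny made-up sample; `WeightedMetricsSketch` and its method names are invented for illustration, and a class that is never predicted is given precision 0 as a sketch-level convention. It also shows why weighted recall always equals accuracy, which matches the two identical values in the results above.

```scala
object WeightedMetricsSketch {
  type Pair = (Double, Double) // (prediction, label)

  /** Fraction of pairs where the prediction equals the label. */
  def accuracy(pairs: Seq[Pair]): Double =
    pairs.count { case (p, l) => p == l }.toDouble / pairs.size

  /** Average a per-class metric, weighting each class by its share of true labels. */
  private def weighted(pairs: Seq[Pair])(perClass: Double => Double): Double = {
    val labels = pairs.map(_._2).distinct
    labels.map { c =>
      val share = pairs.count(_._2 == c).toDouble / pairs.size
      share * perClass(c)
    }.sum
  }

  def weightedPrecision(pairs: Seq[Pair]): Double = weighted(pairs) { c =>
    val predicted = pairs.count(_._1 == c)
    val correct = pairs.count { case (p, l) => p == c && l == c }
    if (predicted == 0) 0.0 else correct.toDouble / predicted
  }

  def weightedRecall(pairs: Seq[Pair]): Double = weighted(pairs) { c =>
    val actual = pairs.count(_._2 == c)
    val correct = pairs.count { case (p, l) => p == c && l == c }
    correct.toDouble / actual
  }

  def main(args: Array[String]): Unit = {
    // Made-up sample: three correct predictions and one missed click.
    val pairs = Seq((0.0, 0.0), (0.0, 0.0), (0.0, 1.0), (1.0, 1.0))
    println(s"accuracy=${accuracy(pairs)}")
    println(s"weightedPrecision=${weightedPrecision(pairs)}")
    println(s"weightedRecall=${weightedRecall(pairs)}")
  }
}
```

Each class's recall is its correct count divided by its label count, and its weight is its label count divided by the total, so the weighted sum collapses to total correct over total, which is the accuracy.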
The complete program:
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{FeatureHasher, StringIndexer}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.{Row, SparkSession}

object AdsCtrPredictionLR {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("AdsCtrPredictionLR")
      .master("local[2]")
      .config("spark.some.config.option", "some-value")
      .getOrCreate()
    /**
     * id is the ad identifier and click indicates whether the ad was clicked.
     * site_id,site_domain,site_category,app_id,app_domain,app_category,device_id,device_ip,device_model
     * are categorical features and are indexed with StringIndexer.
     * device_type,device_conn_type,C14,C15,C16,C17,C18,C19,C20,C21 are numeric features and are passed
     * to FeatureHasher directly (read.csv without a schema loads them as strings, so FeatureHasher
     * hashes them as categorical values).
     */
    val data = spark.read.csv("/opt/data/ads_6M.csv").toDF(
      "id","click","hour","C1","banner_pos","site_id","site_domain",
      "site_category","app_id","app_domain","app_category","device_id","device_ip",
      "device_model","device_type","device_conn_type","C14","C15","C16","C17","C18",
      "C19","C20","C21")
    data.show(5,false)
    val splited = data.randomSplit(Array(0.7,0.3),2L)
    val catalog_features = Array("click","site_id","site_domain","site_category","app_id","app_domain","app_category","device_id","device_ip","device_model")
    var train_index = splited(0)
    var test_index = splited(1)
    for (catalog_feature <- catalog_features) {
      // handleInvalid = "keep": values seen only in the test set get an extra index.
      val indexer = new StringIndexer()
        .setInputCol(catalog_feature)
        .setOutputCol(catalog_feature.concat("_index"))
        .setHandleInvalid("keep")
      // Fit on the training set only, then reuse the fitted model for both sets.
      val train_index_model = indexer.fit(train_index)
      train_index = train_index_model.transform(train_index)
      test_index = train_index_model.transform(test_index)
    }
    println("String-indexed labels:")
    train_index.show(5,false)
    test_index.show(5,false)
    // Feature hashing
    val hasher = new FeatureHasher()
      .setInputCols("site_id_index","site_domain_index","site_category_index","app_id_index","app_domain_index","app_category_index","device_id_index","device_ip_index","device_model_index","device_type","device_conn_type","C14","C15","C16","C17","C18","C19","C20","C21")
      .setOutputCol("feature")
    val train_hs = hasher.transform(train_index)
    val test_hs = hasher.transform(test_index)
    println("Feature hashing output:")
    train_hs.show(5,false)
    test_hs.show(5,false)
    /**
     * LR modelling
     * setMaxIter: maximum number of iterations (default 100); training may stop earlier (see setTol)
     * setTol: convergence tolerance (default 1e-6); each iteration computes an error that shrinks as training proceeds, and iteration stops once it falls below the tolerance
     * setRegParam: regularization parameter (default 0); regularization guards against overfitting, which is likely when the dataset is small and the feature dimension is high, so consider increasing it in that case
     * setElasticNetParam: elastic-net mixing ratio (default 0); 0 gives a pure L2 (ridge) penalty and 1 a pure L1 (lasso) penalty; L1 encourages sparse coefficients, L2 guards against overfitting
     * setLabelCol: label column
     * setFeaturesCol: features column
     * setPredictionCol: prediction column
     * setThreshold: decision threshold for binary classification (default 0.5)
     */
    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.3)
      .setElasticNetParam(0.0)
      .setFeaturesCol("feature")
      .setLabelCol("click_index")
      .setPredictionCol("click_predict")
    val model_lr = lr.fit(train_hs)
    println(s"Coefficients: ${model_lr.coefficients} Intercept: ${model_lr.intercept}")
    val predictions = model_lr.transform(test_hs)
    predictions.select("click_index","click_predict","probability").show(100,false)
    // MulticlassMetrics expects an RDD of (prediction, label) pairs.
    val predictionRdd = predictions.select("click_predict","click_index").rdd.map{
      case Row(click_predict:Double,click_index:Double)=>(click_predict,click_index)
    }
    val metrics = new MulticlassMetrics(predictionRdd)
    val accuracy = metrics.accuracy
    val weightedPrecision = metrics.weightedPrecision
    val weightedRecall = metrics.weightedRecall
    val f1 = metrics.weightedFMeasure
    println(s"LR evaluation results:\nAccuracy: ${accuracy}\nWeighted precision: ${weightedPrecision}\nWeighted recall: ${weightedRecall}\nWeighted F1: ${f1}")
    spark.stop()
  }
}