本文尝试使用Spark提供的机器学习算法 Gradient-Boosted Trees来预测一个用户是否会点击广告。
训练和测试数据使用Kaggle Avazu CTR 比赛的样例数据,下载地址:https://www.kaggle.com/c/avazu-ctr-prediction/data
数据格式如下:
包含24个字段:
• 1-id: ad identifier
• 2-click: 0/1 for non-click/click
• 3-hour: format is YYMMDDHH, so 14091123 means 23:00 on Sept. 11, 2014 UTC.
• 4-C1 — anonymized categorical variable
• 5-banner_pos
• 6-site_id
• 7-site_domain
• 8-site_category
• 9-app_id
• 10-app_domain
• 11-app_category
• 12-device_id
• 13-device_ip
• 14-device_model
• 15-device_type
• 16-device_conn_type
• 17~24—C14-C21 — anonymized categorical variables
其中第5到15列为分类特征,第16~24列为数值型特征。注意:下面的代码在切片时不包含右端点,实际只取第6~14列(共9个)作为分类特征、第16~23列(共8个)作为数值特征,与后面的示例输出一致。
Spark代码如下:
1. package com.lxw1234.test
2.
3. import scala.collection.mutable.ListBuffer
4. import scala.collection.mutable.ArrayBuffer
5.
6. import org.apache.spark.SparkContext
7. import org.apache.spark.SparkContext._
8. import org.apache.spark.SparkConf
9. import org.apache.spark.rdd.RDD
10.
11. import org.apache.spark.mllib.classification.NaiveBayes
12. import org.apache.spark.mllib.regression.LabeledPoint
13. import org.apache.spark.mllib.linalg.Vectors
14.
15. import org.apache.spark.mllib.tree.GradientBoostedTrees
16. import org.apache.spark.mllib.tree.configuration.BoostingStrategy
17. import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
18.
19. /**
20. * By: lxw
21. * http://lxw1234.com
22. */
23. object CtrPredict {
24.
/**
 * Tags every categorical value with the position of the column it came from,
 * so that equal strings appearing in different columns remain distinct features.
 *
 * Example:
 *   input  : Array(1fbe01fe, f3845767, 28905ebd, ecad2386, 7801e8d9)
 *   output : List((0,1fbe01fe), (1,f3845767), (2,28905ebd), (3,ecad2386), (4,7801e8d9))
 *
 * @param catfeatures raw categorical values of one record, in column order
 * @return one (columnIndex, value) pair per input element, in the same order;
 *         an empty list for an empty array
 */
def parseCatFeatures(catfeatures: Array[String]): List[(Int, String)] =
  // zipWithIndex yields (value, index); swap to get the (index, value) shape used as lookup key.
  catfeatures.zipWithIndex.map(_.swap).toList
34.
/**
 * Splits one raw CSV record into (key, categorical features, numerical features).
 *
 * The key is "<id>::<click>", built from the first two columns.
 *
 * Slice bounds are end-exclusive, so with 0-based token indices this selects
 * tokens 5..13 as categorical features and tokens 15..(last - 1) as numerical features.
 * NOTE(review): the article text speaks of columns 6-15 and 16-24, but token 14
 * (column 15) and the last column are not selected here. The bounds are kept as they
 * were so the recorded results stay reproducible; confirm whether that is intended.
 *
 * @param line one comma-separated input line
 * @return (key, categorical values, numerical values as strings)
 */
private def parseRecord(line: String): (String, Array[String], Array[String]) = {
  // limit = -1 keeps trailing empty fields, so column positions are identical for every
  // record (training and test data must be parsed the same way).
  val tokens = line.split(",", -1)
  val catkey = tokens(0) + "::" + tokens(1)
  val catfeatures = tokens.slice(5, 14)
  val numericalfeatures = tokens.slice(15, tokens.length - 1)
  (catkey, catfeatures, numericalfeatures)
}

/**
 * Converts a parsed record into a LabeledPoint.
 *
 * Each categorical value is replaced by the numeric id stored in `oheMap` for its
 * (columnIndex, value) pair; numerical values are parsed and negative ones become 0.
 * The label is the click flag taken from the second part of the key.
 *
 * NOTE(review): a value missing from `oheMap` is encoded as 0.0, which is also the id
 * of one real training value (ids start at 0). Kept as in the original run.
 *
 * @param record (key, categorical values, numerical values) as produced by parseRecord
 * @param oheMap id assigned to every (columnIndex, value) pair seen in the training data
 * @return the labeled feature vector: encoded categorical ids followed by numerical values
 */
private def toLabeledPoint(
    record: (String, Array[String], Array[String]),
    oheMap: scala.collection.Map[(Int, String), Long]): LabeledPoint = {
  val (key, categoricalFeatures, numericalFeatures) = record
  val encoded = parseCatFeatures(categoricalFeatures).map(k => oheMap.getOrElse(k, 0L).toDouble)
  // Values are expected to be integer strings; anything below zero is clamped to 0.
  val numeric = numericalFeatures.map(x => if (x.toInt < 0) 0.0 else x.toDouble)
  LabeledPoint(key.split("::")(1).toInt, Vectors.dense(encoded.toArray ++ numeric))
}

/**
 * Trains a Gradient-Boosted Trees classifier on the sample file and prints its accuracy.
 *
 * Steps: read the input, split it 80/20 into training and test data, build the
 * (columnIndex, value) -> id map from the training data, encode both data sets,
 * train and save the model, reload it, and evaluate on the test data.
 *
 * @param args command-line arguments (unused)
 */
def main(args: Array[String]): Unit = {
  val conf = new SparkConf().setMaster("yarn-client")
  val sc = new SparkContext(conf)

  try {
    val ctrRDD = sc.textFile("/tmp/lxw1234/sample.txt", 10)
    println("Total records : " + ctrRDD.count)

    // Use 80% of the data set for training and 20% for testing.
    val Array(trainRaw, testRaw) = ctrRDD.randomSplit(Array(0.8, 0.2), seed = 37L)

    // Mark both splits as cached before the first action so the counts below
    // already populate the cache.
    trainRaw.cache()
    testRaw.cache()

    println("Train records : " + trainRaw.count)
    println("Test records : " + testRaw.count)

    // Sample element of trainRdd from the original run:
    //   (1000009418151094273::0,
    //    Array(1fbe01fe, f3845767, 28905ebd, ecad2386, 7801e8d9, 07d7df22, a99f214a, ddd2926e, 44956a24),
    //    Array(2, 15706, 320, 50, 1722, 0, 35, -1))
    val trainRdd = trainRaw.map(parseRecord)

    // Tag every categorical value with its column index, e.g.
    //   List((0,1fbe01fe), (1,f3845767), (2,28905ebd), ...)
    val trainCatRdd = trainRdd.map(record => parseCatFeatures(record._2))

    // De-duplicate the (columnIndex, value) pairs and assign a running id to each, e.g.
    //   Map((7,608511e9) -> 31527, (7,b2d8fbed) -> 42207, (7,1d3e2fdb) -> 52791, ...)
    val oheMap = trainCatRdd.flatMap(identity).distinct().zipWithIndex().collectAsMap()
    println("Number of features")
    println(oheMap.size)

    // Broadcast the lookup map once instead of serializing it into every task closure.
    val oheBroadcast = sc.broadcast(oheMap)

    // Encode the training data. Sample element from the original run:
    //   (0.0,[43127.0,50023.0,57445.0,13542.0,31092.0,14800.0,23414.0,54121.0,
    //         17554.0,2.0,15706.0,320.0,50.0,1722.0,0.0,35.0,0.0])
    val oheTrainRdd = trainRdd.map(record => toLabeledPoint(record, oheBroadcast.value))

    // Train the model.
    //val boostingStrategy = BoostingStrategy.defaultParams("Regression")
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 100
    boostingStrategy.treeStrategy.numClasses = 2
    boostingStrategy.treeStrategy.maxDepth = 10
    // Empty map: every feature is treated as continuous.
    boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()

    val model = GradientBoostedTrees.train(oheTrainRdd, boostingStrategy)
    // Save the model ...
    model.save(sc, "/tmp/myGradientBoostingClassificationModel")
    // ... and load it back.
    val sameModel = GradientBoostedTreesModel.load(sc, "/tmp/myGradientBoostingClassificationModel")

    // Encode the test data with the same parser and the same lookup map as the training data.
    val oheTestRdd = testRaw.map(parseRecord).map(record => toLabeledPoint(record, oheBroadcast.value))

    // Inspect a few (prediction, label, features) triples.
    val b = oheTestRdd.map(y => (model.predict(y.features), y.label, y.features))
    b.take(10).foreach(println)

    // Prediction accuracy, computed with the reloaded model.
    val predictionAndLabel = oheTestRdd.map(lp => (sameModel.predict(lp.features), lp.label))
    predictionAndLabel.map(_._1).take(10).foreach(println)
    val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count / oheTestRdd.count
    println("GBTR accuracy " + accuracy)
    // Original run: GBTR accuracy 0.8227084119200302
  } finally {
    // Release cluster resources even when a step above fails.
    sc.stop()
  }
}
172.
173. }
174.
其中,训练数据集: Train records : 104558, 测试数据集:Test records : 26510
程序主要输出:
1. scala> train_rdd.take(1)
2. res23: Array[(String, Array[String], Array[String])] = Array((1000009418151094273::0,
3. Array(1fbe01fe, f3845767, 28905ebd, ecad2386, 7801e8d9, 07d7df22, a99f214a, ddd2926e, 44956a24),
4. Array(2, 15706, 320, 50, 1722, 0, 35, -1)))
5.
6.
7. scala> train_cat_rdd.take(1)
8. res24: Array[List[(Int, String)]] = Array(List((0,1fbe01fe), (1,f3845767), (2,28905ebd),
9. (3,ecad2386), (4,7801e8d9), (5,07d7df22), (6,a99f214a), (7,ddd2926e), (8,44956a24)))
10.
11.
12. scala> println("Number of features")
13. Number of features
14.
15. scala> println(oheMap.size)
16. 57606
17.
18.
19. scala> ohe_train_rdd.take(1)
20. res27: Array[org.apache.spark.mllib.regression.LabeledPoint] = Array(
21. (0.0,[11602.0,22813.0,11497.0,16828.0,30657.0,23893.0,13182.0,31723.0,39722.0,2.0,15706.0,320.0,50.0,1722.0,0.0,35.0,0.0]))
22.
23.
24. scala> println("GBTR accuracy " + accuracy)
25. GBTR accuracy 0.8227084119200302