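Today, while processing the MovieLens dataset with Spark, I needed to add a column to the data, so instead of loading the file directly into a DataFrame I preprocessed it first and added the column before the DataFrame was created.
So I read the data in as an RDD, applied the processing, and then converted the RDD to a DataFrame so that the ml API could be used conveniently.
There are two ways to convert an RDD to a DataFrame: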
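The first way is to use reflection. Spark infers the schema of an RDD that contains objects of a specific type by reflection. This keeps the code concise and works well when you already know the schema while writing the application.
First define a case class: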
case class Person(name: String, age: Int)
Then convert the RDD to a DataFrame:
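import sqlContext.implicits._  // required for .toDF() on an RDD (Spark 1.x SQLContext API)
val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Person(p(0), p(1).trim.toInt)).toDF()
people.registerTempTable("people")  // deprecated since Spark 2.0, where createOrReplaceTempView replaces it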
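The snippet above is written against the Spark 1.x `SQLContext` API. On Spark 2.x and later the same idea can be sketched with `SparkSession`. The object name `ReflectionExample`, the helper `peopleDF` and the sample rows below are made up for illustration, and the case class is declared at the top level because Spark needs it to be defined outside the method that calls `toDF()`.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

// Declared outside the method that uses it so Spark can derive an encoder for it.
case class Person(name: String, age: Int)

object ReflectionExample {
  // Hypothetical helper: turns "name, age" lines into a DataFrame via reflection.
  def peopleDF(spark: SparkSession, lines: Seq[String]): DataFrame = {
    import spark.implicits._ // brings .toDF() into scope for RDDs of case classes
    spark.sparkContext
      .parallelize(lines)
      .map(_.split(","))
      .map(p => Person(p(0).trim, p(1).trim.toInt))
      .toDF()
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder
      .appName("ReflectionExample")
      .master("local[*]")
      .getOrCreate()

    // Sample rows (assumed), in the same "name, age" layout as people.txt.
    val people = peopleDF(spark, Seq("Michael, 29", "Andy, 30", "Justin, 19"))
    people.printSchema()                     // column names come from the case class fields
    people.createOrReplaceTempView("people") // Spark 2.x replacement for registerTempTable
    spark.sql("SELECT name FROM people WHERE age BETWEEN 13 AND 19").show()

    spark.stop()
  }
}
```

The column names and types are taken from the fields of `Person`, which is why no schema has to be written by hand here.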
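The second way is to specify the schema programmatically. First build the schema:
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}
val schemaString = "name age"  // the field names, encoded in a string
val schema = StructType(schemaString.split(" ").map(fieldName => StructField(fieldName, StringType, true)))
Then apply the schema to the RDD. Note that the input here is the raw text RDD, not the DataFrame from the first approach:
val rowRDD = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Row(p(0), p(1).trim))
val peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema)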
The official documentation describes it as follows:
When case classes cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), a DataFrame can be created programmatically with three steps.
http://spark.apache.org/docs/latest/sql-programming-guide.html#programmatically-specifying-the-schema
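The three steps the guide refers to are: build an RDD of `Row`s, build a `StructType` that matches those rows, and apply the schema with `createDataFrame`. A minimal sketch on Spark 2.x follows. The object name `SchemaExample`, the helper `peopleDF` and the sample rows are hypothetical, and typing `age` as `IntegerType` is my own choice for illustration rather than something the string-only snippet above does.

```scala
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object SchemaExample {
  // Hypothetical helper: builds a DataFrame from "name, age" lines with an explicit schema.
  def peopleDF(spark: SparkSession, lines: Seq[String]): DataFrame = {
    // Step 1: an RDD[Row] built from the raw lines.
    val rowRDD = spark.sparkContext
      .parallelize(lines)
      .map(_.split(","))
      .map(p => Row(p(0).trim, p(1).trim.toInt))

    // Step 2: a schema whose field order and types match the Rows above.
    val schema = StructType(Seq(
      StructField("name", StringType, nullable = true),
      StructField("age", IntegerType, nullable = true)))

    // Step 3: apply the schema to the RDD of Rows.
    spark.createDataFrame(rowRDD, schema)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder
      .appName("SchemaExample")
      .master("local[*]")
      .getOrCreate()

    val people = peopleDF(spark, Seq("Michael, 29", "Andy, 30")) // sample rows (assumed)
    people.printSchema()
    people.show()

    spark.stop()
  }
}
```

Spark does not check the `Row` values against the schema when `createDataFrame` is called, so a mismatch (for example a `String` value in an `IntegerType` field) only fails once an action such as `show()` runs.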
The dataset I used has 4 columns: "userId", "movieId", "rating" and "timestamp". I keep the first 3 and add a column, Favorable, that indicates whether rating is greater than 3.
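The per-row rule can be written as a plain Scala function and checked without Spark. This is only a sketch: `RatingRule`, `Rating` and `parse` are names I made up, and returning `None` for a header row or a malformed line is an assumption about how bad input should be handled.

```scala
import scala.util.Try

object RatingRule {
  // The first three columns plus the derived flag; the timestamp is dropped.
  final case class Rating(userId: String, movieId: String, rating: String, favorable: Boolean)

  // Parses one CSV line. Yields None when the line has fewer than three fields
  // or when the third field is not a number (for example a header row).
  def parse(line: String): Option[Rating] =
    line.split(",") match {
      case Array(userId, movieId, rating, _*) =>
        Try(rating.toDouble).toOption.map(r => Rating(userId, movieId, rating, r > 3))
      case _ => None
    }

  def main(args: Array[String]): Unit = {
    println(parse("1,31,2.5,1260759144"))             // Some(Rating(1,31,2.5,false))
    println(parse("1,1029,4.0,1260759179"))           // Some(Rating(1,1029,4.0,true))
    println(parse("userId,movieId,rating,timestamp")) // None
  }
}
```

Keeping the rule in a small function like this makes it easy to reuse inside an RDD `map` or `flatMap` later.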
package ml

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types._

object movielens {
  def main(args: Array[String]): Unit = {
    // Only show errors from Spark's own logging.
    Logger.getLogger("org").setLevel(Level.ERROR)
    val spark = SparkSession
      .builder
      .appName("MovieLensExample")
      .config("spark.sql.warehouse.dir", "file:///")
      .master("local")
      .getOrCreate()
    // Keep the first three fields of each line and append rating > 3 as a fourth.
    val ratings = spark.sparkContext.textFile("F:\\program\\MyPrograms\\data\\ratings.csv")
      .filter(line => !line.startsWith("userId")) // skip the header row, if the file has one
      .map(_.split(","))
      .map(fields => Row(fields(0), fields(1), fields(2), fields(2).toDouble > 3)) //.toDF("userId","movieId","rating","Favorable")
    // The field order and types must match the Rows built above.
    val schema =
      StructType(
        StructField("userId", StringType, true) ::
        StructField("movieId", StringType, true) ::
        StructField("rating", StringType, true) ::
        StructField("Favorable", BooleanType, true) :: Nil)
    val ratingsDF = spark.createDataFrame(ratings, schema)
    ratingsDF.show()
    spark.stop()
  }
}
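The commented-out `.toDF(...)` call in the listing points at a shorter variant: map each line to a tuple instead of a `Row` and let `toDF` name the columns, so no explicit `StructType` is needed. Below is a sketch under the same input assumptions as the listing (comma-separated lines, rating in the third field). The object name `MovielensToDF` and the helper `ratingsDF` are hypothetical, and here the rating is parsed to `Double` rather than kept as a string, which is my own choice.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import scala.util.Try

object MovielensToDF {
  // Hypothetical helper: builds the four-column DataFrame from raw CSV lines.
  def ratingsDF(spark: SparkSession, lines: Seq[String]): DataFrame = {
    import spark.implicits._ // brings .toDF(...) into scope for RDDs of tuples
    spark.sparkContext
      .parallelize(lines)
      .map(_.split(","))
      // Drop a header row or any line whose third field is not numeric.
      .filter(f => f.length >= 3 && Try(f(2).toDouble).isSuccess)
      .map(f => (f(0), f(1), f(2).toDouble, f(2).toDouble > 3))
      .toDF("userId", "movieId", "rating", "Favorable")
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder
      .appName("MovielensToDF")
      .master("local[*]")
      .getOrCreate()

    // Sample lines (assumed) in the userId,movieId,rating,timestamp layout.
    val sample = Seq(
      "userId,movieId,rating,timestamp",
      "1,31,2.5,1260759144",
      "1,1029,4.0,1260759179")
    ratingsDF(spark, sample).show()

    spark.stop()
  }
}
```

With tuples the column types are inferred from the tuple's element types, so the explicit schema is only needed when the structure is not known at compile time.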