Converting an RDD to a DataFrame

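Today I was processing the MovieLens dataset with Spark. I needed to add a column to the data, so I could not read the dataset directly into a DataFrame. Instead, the dataset had to be preprocessed to add that column before the DataFrame was created.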

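So I read the data as an RDD, processed it, and then converted the processed RDD into a DataFrame so that I could use the ml API.
There are two ways to convert an RDD into a DataFrame: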

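  • Using reflection. Spark uses reflection to infer the schema of an RDD that contains objects of a specific type (in Scala, a case class). This approach keeps the code concise and works well when you already know the schema.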

    First define a case class (the Scala counterpart of a Java bean):

case class Person(name: String, age: Int)

Then convert the RDD into a DataFrame:

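import sqlContext.implicits._  // provides .toDF(); with Spark 2.x+ use import spark.implicits._
val people = sc.textFile("examples/src/main/resources/people.txt")
  .map(_.split(",")).map(p => Person(p(0), p(1).trim.toInt)).toDF()
people.registerTempTable("people")  // deprecated since Spark 2.0; prefer createOrReplaceTempView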
  • Using the programmatic interface. Construct a schema and apply it to an existing RDD.

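First create a schema (here every field is a nullable string, and the field names come from schemaString):

import org.apache.spark.sql.types.{StructType, StructField, StringType}
val schemaString = "name age"
val schema = StructType(schemaString.split(" ").map(fieldName => StructField(fieldName, StringType, true)))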

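Then build an RDD of Rows from the raw text file and apply the schema to it:

import org.apache.spark.sql.Row
val rowRDD = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Row(p(0), p(1).trim))
val peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema)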
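The reflection approach from the first bullet can also be written as a standalone SparkSession application instead of spark-shell code. The following is a minimal sketch, assuming Spark 2.x or later is on the classpath; the object name `ReflectionExample`, the helper `toPeopleDF`, and the sample lines are illustrative only.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

// Defined at the top level (not inside a method) so Spark can derive an encoder for it by reflection.
case class Person(name: String, age: Int)

object ReflectionExample {
  // Parses "name, age" lines into Person objects; the column names and types of the
  // resulting DataFrame are inferred from the case class fields.
  def toPeopleDF(spark: SparkSession, lines: Seq[String]): DataFrame = {
    import spark.implicits._ // provides .toDF() for RDDs of case classes
    spark.sparkContext
      .parallelize(lines)
      .map(_.split(","))
      .map(p => Person(p(0), p(1).trim.toInt))
      .toDF()
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder
      .appName("ReflectionExample")
      .master("local[*]")
      .getOrCreate()

    // Sample lines in the same "name, age" format as people.txt
    val people = toPeopleDF(spark, Seq("Michael, 29", "Andy, 30", "Justin, 19"))
    people.printSchema()

    // createOrReplaceTempView is the Spark 2.x+ replacement for registerTempTable
    people.createOrReplaceTempView("people")
    spark.sql("SELECT name FROM people WHERE age >= 20").show()

    spark.stop()
  }
}
```

Because `age` is declared as `Int` in the case class, the inferred column is an integer column rather than a string, which is the main practical advantage over the all-string schema built in the second bullet.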

The official documentation describes it as follows:

When case classes cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), a DataFrame can be created programmatically with three steps.

  1. Create an RDD of Rows from the original RDD
  2. Create the schema represented by a StructType matching the structure of Rows in the RDD created in Step 1.
  3. Apply the schema to the RDD of Rows via createDataFrame method provided by SparkSession.

http://spark.apache.org/docs/latest/sql-programming-guide.html#programmatically-specifying-the-schema
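The three steps above map one-to-one onto code. This is a minimal sketch, assuming Spark 2.x or later; `ProgrammaticSchemaExample` and `toPeopleDF` are hypothetical names, and the schema string is passed in at runtime to mimic a record structure that is not known when the program is compiled.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object ProgrammaticSchemaExample {
  def toPeopleDF(spark: SparkSession, lines: Seq[String], schemaString: String): DataFrame = {
    // Step 1: create an RDD of Rows from the original RDD of text lines
    val rowRDD: RDD[Row] = spark.sparkContext
      .parallelize(lines)
      .map(_.split(","))
      .map(p => Row(p(0), p(1).trim))

    // Step 2: build a StructType matching the Rows (every field is a nullable string here)
    val schema = StructType(
      schemaString.split(" ").map(fieldName => StructField(fieldName, StringType, nullable = true)))

    // Step 3: apply the schema to the RDD of Rows
    spark.createDataFrame(rowRDD, schema)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder
      .appName("ProgrammaticSchemaExample")
      .master("local[*]")
      .getOrCreate()

    // The field names are only known at runtime, e.g. read from a string
    val people = toPeopleDF(spark, Seq("Michael, 29", "Andy, 30"), "name age")
    people.printSchema()
    people.show()

    spark.stop()
  }
}
```

Note that the values in each Row must match the types declared in the schema; since both fields are declared as `StringType`, the age is kept as a trimmed string rather than converted to an integer.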

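The dataset I used has 4 columns: "userId", "movieId", "rating", and "timestamp". I want to keep the first 3 columns and add a column, Favorable, that indicates whether the rating is greater than 3.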

package ml

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types._

object movielens {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.ERROR)
    val spark = SparkSession
      .builder
      .appName("MovieLensExample")
      .config("spark.sql.warehouse.dir", "file:///")
      .master("local")
      .getOrCreate()

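    // skip the CSV header line ("userId,movieId,rating,timestamp"); "rating".toDouble would throw
    val ratings = spark.sparkContext.textFile("F:\\program\\MyPrograms\\data\\ratings.csv")
      .filter(line => !line.startsWith("userId"))
      .map(_.split(","))
      // first three fields + Favorable (rating > 3); toDF() would not compile here (no implicit Encoder[Row])
      .map(fields => Row(fields(0), fields(1), fields(2), fields(2).toDouble > 3))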

    val schema =
      StructType(
        StructField("userId", StringType, true) ::
          StructField("movieId", StringType, true) ::
          StructField("rating", StringType, true) ::
          StructField("Favorable", BooleanType, true) :: Nil)

    val ratingsDF = spark.createDataFrame(ratings, schema)
    ratingsDF.show()

    spark.stop()

  }
}
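For comparison, the same preprocessing can also be done with the reflection approach, which additionally gives the columns real types (Int, Double) instead of strings. This is a sketch under stated assumptions: the case class `Rating`, the object `MovieLensReflection`, and reading the file path from the first command-line argument are all hypothetical, and the header filter assumes the header line starts with "userId".

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

// Hypothetical case class: the first three columns of ratings.csv plus the derived flag.
// The field is named Favorable to match the column name used in the program above.
case class Rating(userId: Int, movieId: Int, rating: Double, Favorable: Boolean)

object MovieLensReflection {
  // Turns raw "userId,movieId,rating,timestamp" lines into a typed DataFrame.
  def toRatingsDF(spark: SparkSession, lines: RDD[String]): DataFrame = {
    import spark.implicits._ // provides .toDF() for RDDs of case classes
    lines
      .filter(line => !line.startsWith("userId")) // skip the CSV header line, if present
      .map(_.split(","))
      .map { f =>
        val rating = f(2).toDouble
        Rating(f(0).toInt, f(1).toInt, rating, rating > 3) // timestamp (f(3)) is dropped
      }
      .toDF()
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder
      .appName("MovieLensReflection")
      .master("local[*]")
      .getOrCreate()

    // Hypothetical: the path to ratings.csv is passed as the first argument
    val lines = spark.sparkContext.textFile(args(0))
    val ratingsDF = toRatingsDF(spark, lines)
    ratingsDF.printSchema()
    ratingsDF.show()

    spark.stop()
  }
}
```

With typed columns, downstream ml code can use `rating` as a numeric column directly, without casting it from a string first.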
