Importing a CSV file and building a DataFrame training set in Spark ML

The snippet below reads a CSV file with the RDD API, drops the header line, parses each row into a LabeledPoint (last column as the label, the rest as features), and converts the RDD into a DataFrame that can be used as a training set.

import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("lr").master("local[*]").getOrCreate()
import spark.implicits._

val sc = spark.sparkContext
val data = sc.textFile("path.csv")

// Drop the header line before parsing
val head = data.first()
val rawRdd: RDD[Array[String]] = data
  .filter(_ != head)
  .map(_.split(","))

// Last column is the label, the remaining columns are the features
val rdd = rawRdd.map { x =>
  LabeledPoint(x.last.toDouble, Vectors.dense(x.dropRight(1).map(_.toDouble)))
}

val df = spark.createDataFrame(rdd).toDF("label", "features")
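
The appName "lr" suggests the DataFrame is intended for a logistic-regression model. A minimal usage sketch, assuming the label column is binary and the split ratio is just an example, might look like this:

import org.apache.spark.ml.classification.LogisticRegression

// Split the training set into train/test parts (ratio and seed are illustrative)
val Array(train, test) = df.randomSplit(Array(0.8, 0.2), seed = 42L)

val lr = new LogisticRegression()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setMaxIter(100)

val model = lr.fit(train)
model.transform(test).select("label", "prediction").show(5)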

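For comparison, the same training set can be built without touching the RDD API by reading the CSV directly into a DataFrame and assembling the feature vector with VectorAssembler. This is only a sketch: the label column name ("label") and the file path are assumptions, not taken from the original code.

import org.apache.spark.ml.feature.VectorAssembler

val raw = spark.read
  .option("header", "true")      // use the first CSV line as column names
  .option("inferSchema", "true") // infer numeric column types
  .csv("path.csv")

// Assume the label column is named "label"; every other column is a feature
val featureCols = raw.columns.filter(_ != "label")
val assembler = new VectorAssembler()
  .setInputCols(featureCols)
  .setOutputCol("features")

val df2 = assembler.transform(raw).select("label", "features")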