Code snippet 1:
package com.oreilly.learningsparkexamples.scala

import org.apache.spark._
import org.eclipse.jetty.client.ContentExchange
import org.eclipse.jetty.client.HttpClient

object BasicMapPartitions {
  def main(args: Array[String]) {
    val master = args.length match {
      case x: Int if x > 0 => args(0)
      case _ => "local"
    }
    val sc = new SparkContext(master, "BasicMapPartitions", System.getenv("SPARK_HOME"))
    val input = sc.parallelize(List("KK6JKQ", "Ve3UoW", "kk6jlk", "W6BB"))
    val result = input.mapPartitions { signs =>
      // One HttpClient is created per partition, not per element
      val client = new HttpClient()
      client.start()
      signs.map { sign =>
        val exchange = new ContentExchange(true)
        exchange.setURL(s"http://qrzcq.com/call/${sign}")
        client.send(exchange)
        exchange
      }.map { exchange =>
        exchange.waitForDone()
        exchange.getResponseContent()
      }
    }
    println(result.collect().mkString(","))
  }
}
The signs parameter of mapPartitions is an Iterator over all of the elements in a single partition of the input RDD.
The result produced for each partition is also an Iterator, containing the elements obtained after the partition-processing function has transformed that partition's elements.
mapPartitions calls the partition function once for each partition, and the returned results (one Iterator per partition) together form the new RDD.
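To make that concrete, here is a minimal sketch (the SparkContext sc and the sample data are assumptions for illustration, not part of the snippets above) in which the partition function returns one sum per partition and the resulting Iterators are stitched into a new RDD:

// Hypothetical illustration; assumes an existing SparkContext named sc.
val nums = sc.parallelize(1 to 8, numSlices = 2)   // two partitions: 1..4 and 5..8
val perPartitionSums = nums.mapPartitions { iter =>
  // called once per partition; iter walks that partition's elements
  Iterator(iter.sum)                               // must return an Iterator
}
println(perPartitionSums.collect().mkString(","))  // prints something like 10,26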
Consider the following code:
package com.oreilly.learningsparkexamples.scala

import org.apache.spark._

object BasicAvgMapPartitions {
  case class AvgCount(var total: Int = 0, var num: Int = 0) {
    def merge(other: AvgCount): AvgCount = {
      total += other.total
      num += other.num
      this
    }
    def merge(input: Iterator[Int]): AvgCount = {
      input.foreach { elem =>
        total += elem
        num += 1
      }
      this
    }
    def avg(): Float = {
      total / num.toFloat
    }
  }

  def main(args: Array[String]) {
    val master = args.length match {
      case x: Int if x > 0 => args(0)
      case _ => "local"
    }
    val sc = new SparkContext(master, "BasicAvgMapPartitions", System.getenv("SPARK_HOME"))
    val input = sc.parallelize(List(1, 2, 3, 4))
    val result = input.mapPartitions(partition =>
      // Here we only want to return a single element for each partition,
      // but mapPartitions requires that we wrap our return in an Iterator
      Iterator(AvgCount(0, 0).merge(partition)))
      .reduce((x, y) => x.merge(y))
    println(result)
  }
}
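Running this on List(1, 2, 3, 4), each partition is folded into a single AvgCount, and reduce merges the per-partition results into AvgCount(10, 4), i.e. an average of 2.5.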
mapPartitionsWithIndex is essentially the same as mapPartitions, except that its processing function takes two parameters: the first is the index of the partition currently being processed, and the second is an Iterator over that partition's elements.
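As an illustration, here is a minimal sketch of mapPartitionsWithIndex; the SparkContext sc and the sample data are assumptions for the example, not taken from the snippets above:

// Hypothetical illustration; assumes an existing SparkContext named sc.
val data = sc.parallelize(List("a", "b", "c", "d"), numSlices = 2)
val tagged = data.mapPartitionsWithIndex { (index, iter) =>
  // index is the partition number, iter holds that partition's elements
  iter.map(elem => s"partition $index: $elem")
}
println(tagged.collect().mkString(", "))
// e.g. partition 0: a, partition 0: b, partition 1: c, partition 1: d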