【Spark Java API】Action (2): fold, countByKey

fold


Official documentation:

Aggregate the elements of each partition, and then the results for all the partitions, 
using a given associative and commutative function and a neutral "zero value". 
The function op(t1, t2) is allowed to modify t1 and return it as its result value 
to avoid object allocation; however, it should not modify t2.

Function signature:

def fold(zeroValue: T)(f: JFunction2[T, T, T]): T

**
fold is a simplified form of aggregate: it is equivalent to calling aggregate with the same function op used as both seqOp and combOp, as illustrated in the sketch below.
**
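
A minimal sketch of this relationship (assuming Java 8 lambdas; the RDD nums and its data are hypothetical): fold(zero, op) gives the same result as aggregate(zero, op, op) when op is associative and commutative.

JavaRDD<Integer> nums = javaSparkContext.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
// fold applies one function both inside each partition and when merging partition results ...
Integer viaFold = nums.fold(0, (a, b) -> a + b);
// ... which is exactly aggregate with seqOp == combOp.
Integer viaAggregate = nums.aggregate(0, (a, b) -> a + b, (a, b) -> a + b);
// Both evaluate to 15 here because addition is associative and commutative.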

Source code analysis:

def fold(zeroValue: T)(op: (T, T) => T): T = withScope {  
  // Clone the zero value since we will also be serializing it as part of tasks 
  var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())  
  val cleanOp = sc.clean(op)  
  val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp)  
  val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult)  
  sc.runJob(this, foldPartition, mergeResult)  
  jobResult
}

**
From the source we can see that zeroValue is first assigned to jobResult; then each partition is folded with op, starting from zeroValue; next, op merges each partition's taskResult into jobResult, updating jobResult as it goes; finally, jobResult is returned. A small numeric sketch of this behavior follows.
**
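
A minimal sketch (with assumed data and a hypothetical RDD nums) showing that the zero value is folded into every partition and then once more when the per-partition results are merged at the driver:

JavaRDD<Integer> nums = javaSparkContext.parallelize(Arrays.asList(1, 2, 3), 3);
// Each of the 3 partitions starts its fold from the zero value 1,
// and the driver-side merge also starts from 1.
Integer result = nums.fold(1, (a, b) -> a + b);
// The elements sum to 6; the zero value is applied 3 + 1 = 4 times, so result is 10.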

Example:

List<String> data = Arrays.asList("5", "1", "1", "3", "6", "2", "2");
JavaRDD<String> javaRDD = javaSparkContext.parallelize(data, 5);
// Tag each element with its partition index so the partition layout is visible.
JavaRDD<String> partitionRDD = javaRDD.mapPartitionsWithIndex(new Function2<Integer, Iterator<String>, Iterator<String>>() {
  @Override
  public Iterator<String> call(Integer v1, Iterator<String> v2) throws Exception {
    LinkedList<String> linkedList = new LinkedList<String>();
    while (v2.hasNext()) {
      linkedList.add(v1 + "=" + v2.next());
    }
    return linkedList.iterator();
  }
}, false);

System.out.println(partitionRDD.collect());

String foldRDD = javaRDD.fold("0", new Function2<String, String, String>() {
    @Override
    public String call(String v1, String v2) throws Exception {
        return v1 + " - " + v2;
    }
});
System.out.println("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" + foldRDD);

countByKey


Official documentation:

Count the number of elements for each key, collecting the results to a local Map.
Note that this method should only be used if the resulting map is expected to be small, 
as the whole thing is loaded into the driver's memory. To handle very large results, 
consider using rdd.mapValues(_ => 1L).reduceByKey(_ + _), 
which returns an RDD[T, Long] instead of a map.

Function signature:

def countByKey(): java.util.Map[K, Long]

Source code analysis:

def countByKey(): Map[K, Long] = self.withScope {  
   self.mapValues(_ => 1L).reduceByKey(_ + _).collect().toMap
}

**
From the source we can see that countByKey first maps each value to 1L, forming (key, 1) pairs, then aggregates them with reduceByKey, and finally uses collect to bring the data to the driver and convert it into a Map.
Note that this means countByKey loads the entire result into driver memory, so an OOM can occur when the data volume is large. If there are many distinct keys, it is therefore recommended to use rdd.mapValues(_ => 1L).reduceByKey(_ + _) instead, which returns an RDD[T, Long]; a sketch of this alternative follows below.
**
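
A minimal sketch (with hypothetical sample data and variable names) of that recommended alternative, which keeps the counts distributed in an RDD instead of collecting a Map into driver memory:

JavaPairRDD<String, Long> counts = javaSparkContext
    .parallelize(Arrays.asList("a", "b", "a"))
    .mapToPair(s -> new Tuple2<String, String>(s, s))
    .mapValues(v -> 1L)
    .reduceByKey((a, b) -> a + b);
// The counts remain distributed and can be transformed or written out
// without materializing every key on the driver.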

Example:

List<String> data = Arrays.asList("5", "1", "1", "3", "6", "2", "2");
JavaRDD<String> javaRDD = javaSparkContext.parallelize(data, 5);

// Tag each element with its partition index so the partition layout is visible.
JavaRDD<String> partitionRDD = javaRDD.mapPartitionsWithIndex(new Function2<Integer, Iterator<String>, Iterator<String>>() {
  @Override
  public Iterator<String> call(Integer v1, Iterator<String> v2) throws Exception {
    LinkedList<String> linkedList = new LinkedList<String>();
    while (v2.hasNext()) {
      linkedList.add(v1 + "=" + v2.next());
    }
    return linkedList.iterator();
  }
}, false);
System.out.println(partitionRDD.collect());
JavaPairRDD<String, String> javaPairRDD = javaRDD.mapToPair(new PairFunction<String, String, String>() {
  @Override
  public Tuple2<String, String> call(String s) throws Exception {
    return new Tuple2<String, String>(s, s);
  }
});
System.out.println(javaPairRDD.countByKey());
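
Since each string is used as its own key here, the printed map simply counts occurrences per string: "1" and "2" each map to 2, while "5", "3", and "6" map to 1 (the entry order of the printed map may vary).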
