Spark 2.X - Custom Accumulators with AccumulatorV2

  • An accumulator is one of Spark's shared-variable implementations. When it is used for counting or for computing aggregate metrics, it can noticeably reduce network overhead, because only small accumulated values are sent back to the driver instead of the raw records.
  1. In Spark, one node takes the Master role (assigned according to the configuration files); the Master is mainly responsible for resource scheduling across the worker nodes.

  2. The actual computation of a Spark job is carried out by the executors on the worker nodes. At the start of a job the data needed by each task is shipped to the executors, so if there are many tasks or the data set is large, a lot of time is spent on this copying.


    With an accumulator, each task adds its partial result locally on its executor, and Spark merges those partial results back at the driver; only the merged values cross the network, which keeps network usage low. A minimal sketch using Spark's built-in LongAccumulator follows this list.
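
For reference, Spark 2.x already ships with built-in accumulators (LongAccumulator, DoubleAccumulator, CollectionAccumulator) that follow the same register / add-in-tasks / read-on-driver pattern as the custom accumulator built below. A minimal sketch, assuming a local run; the class name BuiltInAccumulatorExample and the app name are made up for illustration:

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.LongAccumulator;

import java.util.Arrays;

public class BuiltInAccumulatorExample {
	public static void main(String[] args) {
		SparkSession spark = SparkSession.builder()
				.master("local[2]")
				.appName("BuiltInAccumulatorExample")
				.getOrCreate();
		JavaSparkContext sc = JavaSparkContext.fromSparkContext(spark.sparkContext());

		// Register a built-in long accumulator with the SparkContext.
		LongAccumulator lines = sc.sc().longAccumulator("line_count");

		// Each task adds to its local copy; Spark merges the copies back on the driver.
		sc.parallelize(Arrays.asList("a", "b", "c", "d"))
				.foreach(s -> lines.add(1L));

		// Only the merged value is read on the driver.
		System.out.println("lines = " + lines.value());
		spark.stop();
	}
}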

Now we want to implement an accumulation over many strings such as:

ONE:1    TWO:2    THREE:3    ONE:9

and compute the total count for ONE, TWO and THREE. For the input above, the accumulator should end up holding ONE:10, TWO:2 and THREE:3 (the order of the entries may vary, since they come out of a HashMap). To do this, first define a custom AccumulatorV2:

import org.apache.spark.util.AccumulatorV2;

import java.util.HashMap;
import java.util.Map;

public class UserDefinedAccumulator extends AccumulatorV2<String, String> {
	private String one = "ONE";
	private String two = "TWO";
	private String three = "THREE";
	// Concatenate the keys we want to count, each with an initial count of 0;
	// add() and merge() keep accumulating into this string and value() returns it.
	private String data = one + ":0;" + two + ":0;" + three + ":0;";
	// The initial state, used by isZero() and reset().
	private String zero = data;

	// The accumulator is "zero" as long as its value still equals the initial string.
	@Override
	public boolean isZero() {
		return data.equals(zero);
	}

	// Create a new accumulator of this type.
	@Override
	public AccumulatorV2<String, String> copy() {
		return new UserDefinedAccumulator();
	}

	// Reset the accumulator to its initial state.
	@Override
	public void reset() {
		data = zero;
	}

	// Merge an incoming string such as "ONE:3" into the current value.
	@Override
	public void add(String v) {
		data = mergeData(v, data, ";");
	}

	// Merge the value of another accumulator of the same type into this one.
	@Override
	public void merge(AccumulatorV2<String, String> other) {
		data = mergeData(other.value(), data, ";");
	}

	// Return the accumulated value.
	@Override
	public String value() {
		return data;
	}

	/**
	 * Merge two "KEY:count" strings.
	 * @param data_1 first string, e.g. "ONE:3"
	 * @param data_2 second string, e.g. "ONE:0;TWO:0;THREE:0"
	 * @param delimit delimiter between the KEY:count pairs
	 * @return the merged string
	 */
	private String mergeData(String data_1, String data_2, String delimit) {
		StringBuilder res = new StringBuilder();
		String[] infos_1 = data_1.split(delimit);
		String[] infos_2 = data_2.split(delimit);
		Map<String, Integer> map_1 = new HashMap<>();
		Map<String, Integer> map_2 = new HashMap<>();
		// Parse the "KEY:count" pairs of the first string.
		for (String info : infos_1) {
			String[] kv = info.split(":");
			if (kv.length == 2) {
				map_1.put(kv[0].toUpperCase(), Integer.valueOf(kv[1]));
			}
		}
		// Parse the "KEY:count" pairs of the second string.
		for (String info : infos_2) {
			String[] kv = info.split(":");
			if (kv.length == 2) {
				map_2.put(kv[0].toUpperCase(), Integer.valueOf(kv[1]));
			}
		}
		// Keys of the first map: add the matching count from the second map, if any.
		for (Map.Entry<String, Integer> entry : map_1.entrySet()) {
			String key = entry.getKey();
			Integer value = entry.getValue();
			if (map_2.containsKey(key)) {
				value = value + map_2.get(key);
				map_2.remove(key);
			}
			res.append(key).append(":").append(value).append(delimit);
		}
		// Keys that only appear in the second map.
		for (Map.Entry<String, Integer> entry : map_2.entrySet()) {
			res.append(entry.getKey()).append(":").append(entry.getValue()).append(delimit);
		}
		// Drop the trailing delimiter.
		return res.substring(0, res.length() - 1);
	}
}
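
Since none of these methods require a SparkContext, the add / merge / value / reset logic can be sanity-checked locally before the accumulator is used in a real job. A minimal sketch (the class name UserDefinedAccumulatorCheck is made up for illustration); the expected totals follow from the merge logic above, although the order of the entries in the value string may vary because it is built from a HashMap:

public class UserDefinedAccumulatorCheck {
	public static void main(String[] args) {
		UserDefinedAccumulator acc = new UserDefinedAccumulator();
		System.out.println(acc.isZero());   // true: still in the initial state

		// Simulate the adds that individual tasks would perform.
		acc.add("ONE:1");
		acc.add("TWO:2");

		// Simulate merging a second task's accumulator on the driver.
		UserDefinedAccumulator other = new UserDefinedAccumulator();
		other.add("THREE:3");
		other.add("ONE:9");
		acc.merge(other);

		// Should contain ONE:10, TWO:2 and THREE:3.
		System.out.println(acc.value());

		acc.reset();
		System.out.println(acc.isZero());   // true again after reset()
	}
}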

Create a test class:

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.sql.SparkSession;
import scala.Tuple2;

import java.util.Arrays;
import java.util.Random;

public class CustomerAccumulator {
	public static void main(String[] args) {
		SparkSession sparkSession = SparkSession.builder()
				.master("local[4]")
				.appName("CustomerAccumulator")
				.getOrCreate();
		JavaSparkContext sc = JavaSparkContext.fromSparkContext(sparkSession.sparkContext());
		sc.setLogLevel("ERROR");
		JavaRDD<String> rdd = sc.parallelize(Arrays.asList("ONE", "TWO", "THREE", "ONE"));
		UserDefinedAccumulator count = new UserDefinedAccumulator();
		// Register the accumulator with the SparkContext.
		sc.sc().register(count, "user_count");
		// Attach a random count to each key.
		JavaPairRDD<String, String> pairRDD = rdd.mapToPair((PairFunction<String, String, String>) s -> {
			int num = new Random().nextInt(10);
			return new Tuple2<>(s, s + ":" + num);
		});
		// Accumulate inside foreach (an action), where each record is processed once.
		pairRDD.foreach((VoidFunction<Tuple2<String, String>>) tuple2 -> {
			System.out.println(tuple2._2);
			count.add(tuple2._2);
		});
		System.out.println("the value of accumulator is: " + count.value());
	}
}

Check the run output:

[Figure 1: screenshot of the run output]
