Operators such as reduceByKey and groupByKey work on a PairRDD, so we can prepend a random prefix to the key of every element in the PairRDD. The large number of identical keys that previously caused the data skew is then scattered across many distinct keys, and the skew is avoided.
Before the first round of aggregation, prefix each original key with a random number (anything under 10 is enough) and aggregate on these randomized keys. Records that used to land in a single task because they shared the same key are now spread over many more tasks. Once the first aggregation finishes, strip the random prefix from every key to restore the original keys, and finally run one global aggregation.
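To make the two stages concrete, here is a minimal local-mode sketch of the idea written with Java 8 lambdas; the class name, the local[2] master and the tiny in-memory dataset are made up purely for illustration. The anonymous-class version below does exactly the same thing, plus writing the result to MySQL.

import java.util.Arrays;
import java.util.Map;
import java.util.Random;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;

import scala.Tuple2;

public class TwoStageAggSketch {
    public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext(
                new SparkConf().setAppName("TwoStageAggSketch").setMaster("local[2]"));

        // a tiny skewed dataset: "hot" occurs far more often than "cold" (made-up values)
        JavaPairRDD<String, Integer> pairs = sc.parallelizePairs(Arrays.asList(
                new Tuple2<String, Integer>("hot", 1), new Tuple2<String, Integer>("hot", 1),
                new Tuple2<String, Integer>("hot", 1), new Tuple2<String, Integer>("cold", 1)));

        Map<String, Integer> result = pairs
                // stage 1: scatter each key with a random 0-9 prefix, e.g. "hot" -> "3_hot"
                .mapToPair(t -> new Tuple2<String, Integer>(new Random().nextInt(10) + "_" + t._1, t._2))
                .reduceByKey(Integer::sum)
                // stage 2: drop the prefix ("3_hot" -> "hot") and aggregate globally
                .mapToPair(t -> new Tuple2<String, Integer>(t._1.split("_", 2)[1], t._2))
                .reduceByKey(Integer::sum)
                .collectAsMap();

        System.out.println(result); // e.g. {hot=3, cold=1}
        sc.close();
    }
}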
Code implementation:
package cn.spark.core.common;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.Random;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;
/**
 * Data skew solution: two-stage aggregation with random key prefixes
 */
public class DataSkew {

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("DataSkew");
        JavaSparkContext sc = new JavaSparkContext(conf);

        // create the initial RDD from the input file (one "key,value" record per line)
        JavaRDD<String> initRDD = sc.textFile(args[0]);

        // transform initRDD into a pair RDD of (key, value)
        JavaPairRDD<String, Integer> pairRDD = initRDD.mapToPair(
                new PairFunction<String, String, Integer>() {
                    private static final long serialVersionUID = 2479906636617428526L;
                    @Override
                    public Tuple2<String, Integer> call(String line) throws Exception {
                        String[] arr = line.split(",");
                        String key = arr[0];
                        Integer value = Integer.valueOf(arr[1]);
                        return new Tuple2<String, Integer>(key, value);
                    }
                });
        // stage 1a: add a random prefix (0-9) to every key to scatter the hot keys
        JavaPairRDD<String, Integer> prePairRDD = pairRDD.mapToPair(
                new PairFunction<Tuple2<String, Integer>, String, Integer>() {
                    private static final long serialVersionUID = 1L;
                    @Override
                    public Tuple2<String, Integer> call(Tuple2<String, Integer> tuple) throws Exception {
                        Random random = new Random();
                        int prefix = random.nextInt(10);
                        String key = prefix + "_" + tuple._1;
                        return new Tuple2<String, Integer>(key, tuple._2);
                    }
                });
        // stage 1b: first aggregation on the prefixed keys
        JavaPairRDD<String, Integer> tempPrePairRDD = prePairRDD.reduceByKey(
                new Function2<Integer, Integer, Integer>() {
                    private static final long serialVersionUID = 2021476568518204795L;
                    @Override
                    public Integer call(Integer value1, Integer value2) throws Exception {
                        return value1 + value2;
                    }
                });
        // stage 2a: strip the random prefix to restore the original keys
        // (assumes the original keys do not themselves contain "_")
        JavaPairRDD<String, Integer> initPairRDD = tempPrePairRDD.mapToPair(
                new PairFunction<Tuple2<String, Integer>, String, Integer>() {
                    private static final long serialVersionUID = -178978937197684290L;
                    @Override
                    public Tuple2<String, Integer> call(Tuple2<String, Integer> tuple) throws Exception {
                        String key = tuple._1.split("_")[1];
                        return new Tuple2<String, Integer>(key, tuple._2);
                    }
                });
        // stage 2b: final global aggregation on the original keys
        JavaPairRDD<String, Integer> resultPairRDD = initPairRDD.reduceByKey(
                new Function2<Integer, Integer, Integer>() {
                    private static final long serialVersionUID = -815845668882788529L;
                    @Override
                    public Integer call(Integer value1, Integer value2) throws Exception {
                        return value1 + value2;
                    }
                });

        saveToMysql(resultPairRDD, args[1]);
        sc.close();
    }
    /**
     * save resultPairRDD to MySQL
     *
     * @param resultPairRDD the aggregated (key, value) pairs
     * @param tableName     the target table, qualified with the database name
     */
    public static void saveToMysql(JavaPairRDD<String, Integer> resultPairRDD, String tableName) {
        // create SparkSession object
        SparkSession spark = SparkSession.builder().getOrCreate();

        // convert the pair RDD into an RDD of Rows
        JavaRDD<Row> rowRDD = resultPairRDD.map(
                new Function<Tuple2<String, Integer>, Row>() {
                    private static final long serialVersionUID = 7659308133806959864L;
                    @Override
                    public Row call(Tuple2<String, Integer> tuple) throws Exception {
                        return RowFactory.create(tuple._1, tuple._2);
                    }
                });
        // create the schema: (key STRING, value INT)
        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("key", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("value", DataTypes.IntegerType, true));
        StructType schema = DataTypes.createStructType(fields);

        // create the DataFrame
        Dataset<Row> resultDF = spark.createDataFrame(rowRDD, schema);

        // save to MySQL via JDBC
        Properties properties = new Properties();
        properties.put("driver", "com.mysql.jdbc.Driver");
        properties.put("user", "root");
        properties.put("password", "hadoop");
        resultDF.write().mode("overwrite").jdbc("jdbc:mysql://localhost:3306", tableName, properties);
    }
}
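Two details about saveToMysql are worth calling out: the JDBC URL ("jdbc:mysql://localhost:3306") does not name a database, so the table name passed in has to be qualified with the database (retail_db.dataskew in the script below); and the MySQL driver jar must be visible to both the driver and the executors, which is why the submit script passes it via both --driver-class-path and --jars.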
spark-submit script:
spark-submit \
--class cn.spark.core.common.DataSkew \
--num-executors 1 \
--driver-memory 1000m \
--executor-memory 1000m \
--executor-cores 2 \
--driver-class-path /root/workspace/java/mysql-connector-java.jar \
--jars /root/workspace/java/mysql-connector-java.jar \
/root/workspace/java/spark-java-0.0.1-SNAPSHOT-jar-with-dependencies.jar hdfs:///temp/data/dataskew.txt retail_db.dataskew
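The input file (hdfs:///temp/data/dataskew.txt) is expected to contain one comma-separated record per line in the form key,value with an integer value, and the keys should not contain "_" since the prefix is later stripped with split("_"). Hypothetical sample lines:

user1,1
user1,1
user2,3

The per-key sums then end up in the retail_db.dataskew table.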