写在前面
Spark程序多是Scala语言编写,Scala语法简单,但是对于初学者,无法知道变量类型,不清晰实现细节,所以我使用Java语言编写Spark程序,旨在熟悉RDD算子的编程方法。以KMeans算法为例,管中窥豹,了解如何使用RDD算子编写Spark程序。
本文先讲解使用到的RDD算子,最后附完整代码。
对RDD中的每个记录都使用func进行转换,返回一个新的RDD。
例,把从文件中读取的每个String类型的记录(每行一个坐标经纬度)转成List类型。
JavaRDD<List<Double>> kPoints = sc.textFile("E:/data/kmeans/center.txt").map(new Function<String, List<Double>>() {
@Override
public List<Double> call(String s) throws Exception {
String[] a = s.split(",");
ArrayList<Double> ret = new ArrayList<>();
ret.add(Double.parseDouble(a[0]));
ret.add(Double.parseDouble(a[1]));
return ret;
}
});
对RDD中的每个记录都使用func进行转换,返回一个新的键值对RDD(Java中是JavaPairRDD),使用Tuple2<>类型实现。
例,把每一个点(List类型)转换成键值对,其中该点所在的聚类中心点(由closestPoint()函数求得)为键,该点坐标及属于该聚类中心的点的个数(对于当前点个数是1)为值。
JavaPairRDD<Integer, Pair<List<Integer>, Integer>> closet = points.mapToPair(new PairFunction<List<Integer>, Integer, Pair<List<Integer>, Integer>>() {
@Override
public Tuple2<Integer, Pair<List<Integer>, Integer>> call(List<Integer> point) throws Exception {
return new Tuple2<>(closestPoint(point, kPoints3), new Pair<>(point, 1));
}
});
对键值对RDD中的每个记录,按照func进行合并转换,返回相同类型的键值对RDD。
例,对属于相同聚类中心的点的坐标分别求和(addPoints()函数实现),计数求和。
JavaPairRDD<Integer, Pair<List<Integer>, Integer>> newPoints = closet.reduceByKey(new Function2<Pair<List<Integer>, Integer>, Pair<List<Integer>, Integer>, Pair<List<Integer>, Integer>>() {
@Override
public Pair<List<Integer>, Integer> call(Pair<List<Integer>, Integer> t1, Pair<List<Integer>, Integer> t2) throws Exception {
return new Pair<>(addPoints(t1.getKey(), t2.getKey()), t1.getValue() + t2.getValue());
}
});
对RDD中的每个记录都调用func,func函数没有返回值。
例,遍历输出JavaPairRDD
closet.foreach(new VoidFunction<Tuple2<Integer, Pair<List<Integer>, Integer>>>() {
@Override
public void call(Tuple2<Integer, Pair<List<Integer>, Integer>> t) throws Exception {
Integer index = t._1();
List<Integer> point = t._2().getKey();
Integer t2 = t._2().getValue();
System.out.print(index + " ");
System.out.print(point.get(0) + "," + point.get(1) + " ");
System.out.println(t2);
}
});
将RDD记录转成List<>
collect()算子是Action操作,但是有以下缺点:(1)一次collect需要一次Shuffle,非常耗时;(2)collect操作会将分布在各个节点的数据存到Driver节点,占用内存。
迭代计算经常需要多次重复使用同一组数据,而RDD的惰性机制,是每次遇到Action操作都要重新计算,所以使用cache()后,第一次遇到Action操作计算后,会把计算结果保留在内存中,这就是cache()的持久化或缓存机制。
在使用map()、mapToPair()、reduceByKey()时,传入的参数是匿名类,需要注意如果使用到外部变量,需要是final类型。例,mapToPair()例子中的kPoints3变量。
import javafx.util.Pair;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import scala.Tuple2;
import java.util.ArrayList;
import java.util.List;
public class KMeans {
public static double distanceSquared(List<Integer> point1, List<Double> point2) {
double sum = 0.0;
for (int i = 0; i < point1.size(); i++) {
sum += Math.pow(point1.get(i).doubleValue() - point2.get(i), 2);
}
return sum;
}
private static List<Integer> addPoints(List<Integer> p1, List<Integer> p2) {
ArrayList<Integer> ret = new ArrayList<>();
for (int i = 0; i < p1.size(); i++) {
ret.add(p1.get(i) + p2.get(i));
}
return ret;
}
private static Integer closestPoint(List<Integer> point, List<List<Double>> kPoints2) {
int bestIndex = 0;
double cloest = Double.POSITIVE_INFINITY;
for (int i = 0; i < kPoints2.size(); i++) {
double dist = distanceSquared(point, kPoints2.get(i));
if (dist < cloest) {
cloest = dist;
bestIndex = i;
}
}
return bestIndex;
}
public static void run() {
SparkConf conf = new SparkConf().setAppName("KMeans").setMaster("local");
JavaSparkContext sc = new JavaSparkContext(conf);
int iterateNum = 20;
JavaRDD<List<Integer>> points = sc.textFile("E:/data/kmeans/data.txt").map(new Function<String, List<Integer>>() {
@Override
public List<Integer> call(String s) throws Exception {
String[] a = s.split(",");
ArrayList<Integer> ret = new ArrayList<>();
ret.add(Integer.parseInt(a[0]));
ret.add(Integer.parseInt(a[1]));
return ret;
}
}).cache();
points.foreach(new VoidFunction<List<Integer>>() {
@Override
public void call(List<Integer> point) throws Exception {
System.out.println(point.get(0) + "," + point.get(1));
}
});
JavaRDD<List<Double>> kPoints = sc.textFile("E:/data/kmeans/center.txt").map(new Function<String, List<Double>>() {
@Override
public List<Double> call(String s) throws Exception {
String[] a = s.split(",");
ArrayList<Double> ret = new ArrayList<>();
ret.add(Double.parseDouble(a[0]));
ret.add(Double.parseDouble(a[1]));
return ret;
}
});
List<List<Double>> kPoints2 = kPoints.collect();
for (int iter = 0; iter < iterateNum; iter++) {
final List<List<Double>> kPoints3 = new ArrayList<>(kPoints2);
JavaPairRDD<Integer, Pair<List<Integer>, Integer>> closet = points.mapToPair(new PairFunction<List<Integer>, Integer, Pair<List<Integer>, Integer>>() {
@Override
public Tuple2<Integer, Pair<List<Integer>, Integer>> call(List<Integer> point) throws Exception {
return new Tuple2<>(closestPoint(point, kPoints3), new Pair<>(point, 1));
}
});
JavaPairRDD<Integer, Pair<List<Integer>, Integer>> newPoints = closet.reduceByKey(new Function2<Pair<List<Integer>, Integer>, Pair<List<Integer>, Integer>, Pair<List<Integer>, Integer>>() {
@Override
public Pair<List<Integer>, Integer> call(Pair<List<Integer>, Integer> t1, Pair<List<Integer>, Integer> t2) throws Exception {
return new Pair<>(addPoints(t1.getKey(), t2.getKey()), t1.getValue() + t2.getValue());
}
});
JavaRDD<List<Double>> newPoints2 = newPoints.map(new Function<Tuple2<Integer, Pair<List<Integer>, Integer>>, List<Double>>() {
@Override
public List<Double> call(Tuple2<Integer, Pair<List<Integer>, Integer>> t) throws Exception {
Integer n = t._2().getValue();
List<Integer> point = t._2().getKey();
ArrayList<Double> newPoint = new ArrayList<>();
for (int i = 0; i < point.size(); i++) {
newPoint.add(0.0);
}
for (int i = 0; i < point.size(); i++) {
newPoint.set(i, newPoint.get(i) + point.get(i).doubleValue() / n);
}
return newPoint;
}
});
List<List<Double>> newPoints3 = newPoints2.collect();
kPoints2 = new ArrayList<>(newPoints3);
if (iter == iterateNum - 1) {
closet.foreach(new VoidFunction<Tuple2<Integer, Pair<List<Integer>, Integer>>>() {
@Override
public void call(Tuple2<Integer, Pair<List<Integer>, Integer>> t) throws Exception {
Integer index = t._1();
List<Integer> point = t._2().getKey();
Integer t2 = t._2().getValue();
System.out.print(index + " ");
System.out.print(point.get(0) + "," + point.get(1) + " ");
System.out.println(t2);
}
});
}
}
sc.stop();
}
public static void main(String[] args) {
run();
}
}
其中,data.txt文件为数据集,即待分类样本点,内容如下
0,0
1,2
3,1
8,8
9,10
10,7
center.txt文件为初始聚类中心,内容如下
1,2
3,1