Java语言Spark实现KMeans

写在前面
Spark程序多是Scala语言编写,Scala语法简单,但是对于初学者,无法知道变量类型,不清晰实现细节,所以我使用Java语言编写Spark程序,旨在熟悉RDD算子的编程方法。以KMeans算法为例,管中窥豹,了解如何使用RDD算子编写Spark程序。
本文先讲解使用到的RDD算子,最后附完整代码。

1. map(func)

对RDD中的每个记录都使用func进行转换,返回一个新的RDD。
例,把从文件中读取的每个String类型的记录(每行一个坐标经纬度)转成List类型。

JavaRDD<List<Double>> kPoints = sc.textFile("E:/data/kmeans/center.txt").map(new Function<String, List<Double>>() {
    @Override
    public List<Double> call(String s) throws Exception {
        String[] a = s.split(",");
        ArrayList<Double> ret = new ArrayList<>();
        ret.add(Double.parseDouble(a[0]));
        ret.add(Double.parseDouble(a[1]));
        return ret;
    }
});

2. mapToPair(func)

对RDD中的每个记录都使用func进行转换,返回一个新的键值对RDD(Java中是JavaPairRDD),使用Tuple2<>类型实现。
例,把每一个点(List类型)转换成键值对,其中该点所在的聚类中心点(由closestPoint()函数求得)为键,该点坐标及属于该聚类中心的点的个数(对于当前点个数是1)为值。

JavaPairRDD<Integer, Pair<List<Integer>, Integer>> closet = points.mapToPair(new PairFunction<List<Integer>, Integer, Pair<List<Integer>, Integer>>() {
    @Override
    public Tuple2<Integer, Pair<List<Integer>, Integer>> call(List<Integer> point) throws Exception {
        return new Tuple2<>(closestPoint(point, kPoints3), new Pair<>(point, 1));
    }
});

3. reduceByKey(func)

对键值对RDD中的每个记录,按照func进行合并转换,返回相同类型的键值对RDD。
例,对属于相同聚类中心的点的坐标分别求和(addPoints()函数实现),计数求和。

JavaPairRDD<Integer, Pair<List<Integer>, Integer>> newPoints = closet.reduceByKey(new Function2<Pair<List<Integer>, Integer>, Pair<List<Integer>, Integer>, Pair<List<Integer>, Integer>>() {
    @Override
    public Pair<List<Integer>, Integer> call(Pair<List<Integer>, Integer> t1, Pair<List<Integer>, Integer> t2) throws Exception {
        return new Pair<>(addPoints(t1.getKey(), t2.getKey()), t1.getValue() + t2.getValue());
    }
});

4. foreach(func)

对RDD中的每个记录都调用func,func函数没有返回值。
例,遍历输出JavaPairRDD>变量(Tuple2>类型)。

closet.foreach(new VoidFunction<Tuple2<Integer, Pair<List<Integer>, Integer>>>() {
    @Override
    public void call(Tuple2<Integer, Pair<List<Integer>, Integer>> t) throws Exception {
        Integer index = t._1();
        List<Integer> point = t._2().getKey();
        Integer t2 = t._2().getValue();
        System.out.print(index + " ");
        System.out.print(point.get(0) + "," + point.get(1) + " ");
        System.out.println(t2);
    }
});

5. collect()

将RDD记录转成List<>
collect()算子是Action操作,但是有以下缺点:(1)一次collect需要一次Shuffle,非常耗时;(2)collect操作会将分布在各个节点的数据存到Driver节点,占用内存。

6. cache()

迭代计算经常需要多次重复使用同一组数据,而RDD的惰性机制,是每次遇到Action操作都要重新计算,所以使用cache()后,第一次遇到Action操作计算后,会把计算结果保留在内存中,这就是cache()的持久化或缓存机制。

7. 关于Java匿名类进行函数传递

在使用map()、mapToPair()、reduceByKey()时,传入的参数是匿名类,需要注意如果使用到外部变量,需要是final类型。例,mapToPair()例子中的kPoints3变量。

8. 完整程序

import javafx.util.Pair;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import scala.Tuple2;

import java.util.ArrayList;
import java.util.List;

public class KMeans {
    public static double distanceSquared(List<Integer> point1, List<Double> point2) {
        double sum = 0.0;
        for (int i = 0; i < point1.size(); i++) {
            sum += Math.pow(point1.get(i).doubleValue() - point2.get(i), 2);
        }
        return sum;
    }

    private static List<Integer> addPoints(List<Integer> p1, List<Integer> p2) {
        ArrayList<Integer> ret = new ArrayList<>();
        for (int i = 0; i < p1.size(); i++) {
            ret.add(p1.get(i) + p2.get(i));
        }
        return ret;
    }

    private static Integer closestPoint(List<Integer> point, List<List<Double>> kPoints2) {
        int bestIndex = 0;
        double cloest = Double.POSITIVE_INFINITY;
        for (int i = 0; i < kPoints2.size(); i++) {
            double dist = distanceSquared(point, kPoints2.get(i));
            if (dist < cloest) {
                cloest = dist;
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    public static void run() {
        SparkConf conf = new SparkConf().setAppName("KMeans").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);
        int iterateNum = 20;
        JavaRDD<List<Integer>> points = sc.textFile("E:/data/kmeans/data.txt").map(new Function<String, List<Integer>>() {

            @Override
            public List<Integer> call(String s) throws Exception {
                String[] a = s.split(",");
                ArrayList<Integer> ret = new ArrayList<>();
                ret.add(Integer.parseInt(a[0]));
                ret.add(Integer.parseInt(a[1]));
                return ret;
            }
        }).cache();
        points.foreach(new VoidFunction<List<Integer>>() {

            @Override
            public void call(List<Integer> point) throws Exception {
                System.out.println(point.get(0) + "," + point.get(1));
            }
        });
        JavaRDD<List<Double>> kPoints = sc.textFile("E:/data/kmeans/center.txt").map(new Function<String, List<Double>>() {
            @Override
            public List<Double> call(String s) throws Exception {
                String[] a = s.split(",");
                ArrayList<Double> ret = new ArrayList<>();
                ret.add(Double.parseDouble(a[0]));
                ret.add(Double.parseDouble(a[1]));
                return ret;
            }
        });
        List<List<Double>> kPoints2 = kPoints.collect();
        for (int iter = 0; iter < iterateNum; iter++) {
            final List<List<Double>> kPoints3 = new ArrayList<>(kPoints2);
            JavaPairRDD<Integer, Pair<List<Integer>, Integer>> closet = points.mapToPair(new PairFunction<List<Integer>, Integer, Pair<List<Integer>, Integer>>() {
                @Override
                public Tuple2<Integer, Pair<List<Integer>, Integer>> call(List<Integer> point) throws Exception {
                    return new Tuple2<>(closestPoint(point, kPoints3), new Pair<>(point, 1));
                }
            });
            JavaPairRDD<Integer, Pair<List<Integer>, Integer>> newPoints = closet.reduceByKey(new Function2<Pair<List<Integer>, Integer>, Pair<List<Integer>, Integer>, Pair<List<Integer>, Integer>>() {
                @Override
                public Pair<List<Integer>, Integer> call(Pair<List<Integer>, Integer> t1, Pair<List<Integer>, Integer> t2) throws Exception {
                    return new Pair<>(addPoints(t1.getKey(), t2.getKey()), t1.getValue() + t2.getValue());
                }
            });
            JavaRDD<List<Double>> newPoints2 = newPoints.map(new Function<Tuple2<Integer, Pair<List<Integer>, Integer>>, List<Double>>() {
                @Override
                public List<Double> call(Tuple2<Integer, Pair<List<Integer>, Integer>> t) throws Exception {
                    Integer n = t._2().getValue();
                    List<Integer> point = t._2().getKey();
                    ArrayList<Double> newPoint = new ArrayList<>();
                    for (int i = 0; i < point.size(); i++) {
                        newPoint.add(0.0);
                    }
                    for (int i = 0; i < point.size(); i++) {
                        newPoint.set(i, newPoint.get(i) + point.get(i).doubleValue() / n);
                    }
                    return newPoint;
                }
            });
            List<List<Double>> newPoints3 = newPoints2.collect();
            kPoints2 = new ArrayList<>(newPoints3);

            if (iter == iterateNum - 1) {
                closet.foreach(new VoidFunction<Tuple2<Integer, Pair<List<Integer>, Integer>>>() {
                    @Override
                    public void call(Tuple2<Integer, Pair<List<Integer>, Integer>> t) throws Exception {
                        Integer index = t._1();
                        List<Integer> point = t._2().getKey();
                        Integer t2 = t._2().getValue();
                        System.out.print(index + " ");
                        System.out.print(point.get(0) + "," + point.get(1) + " ");
                        System.out.println(t2);
                    }
                });
            }
        }
        sc.stop();
    }


    public static void main(String[] args) {
        run();
    }
}

其中,data.txt文件为数据集,即待分类样本点,内容如下

0,0
1,2
3,1
8,8
9,10
10,7

center.txt文件为初始聚类中心,内容如下

1,2
3,1

你可能感兴趣的:(分布式,spark,java,kmeans)