这个案例将实现一个简单的K-Means聚类算法。有必要先简单地介绍下K-Means的算法计算原理。
K-Means均值是一种迭代聚类算法,其工作原理如下:
算法在固定次数的迭代后终止 (本案例采用的),或者簇中心不再怎么移动了,那么也可以终止计算。
这个案例是在二维数据点数据集上实现的。输入文件是纯文本文件,文件格式必须要满足如下格式:
二维点数据集表示为两个由空白字符分隔的双精度值。数据点用换行符分隔。
例如:"1.2 2.3\n5.3 7.2\n"将代表两个点,分别是 (x=1.2, y=2.3)和(x=5.3, y=7.2)。
簇中心将由id和点坐标来呈现。
例如:"1 6.2 3.2\n2 2.9 5.7\n"将代表两个簇中心,分别是(id=1, x=6.2, y=3.2)和(id=2, x=2.9, y=5.7)。
通过本案例我们将主要学习如下知识:
本案例主要是讲解一种应用思维方式,所以用来训练的原始数据不多。主要目的是为了展示效果。
public static class Point implements Serializable {
// x坐标,y坐标
public double x, y;
public Point() {}
public Point(double x, double y) {
this.x = x;
this.y = y;
}
// 点坐标的加法器
public Point add(Point other) {
x += other.x;
y += other.y;
return this;
}
// 点坐标的除法器
public Point div(long val) {
x /= val;
y /= val;
return this;
}
// 计算点之间的欧式距离
public double euclideanDistance(Point other) {
return Math.sqrt((x - other.x) * (x - other.x) + (y - other.y) * (y - other.y));
}
public void clear() {
x = y = 0.0;
}
@Override
public String toString() {
return x + " " + y;
}
}
簇中心从物理角度看是称为质心。质心的数据结构代码定义如下:
/**
* 质心类, 基于点坐标和id.
*/
public static class Centroid extends Point {
public int id;
public Centroid() {}
public Centroid(int id, double x, double y) {
super(x, y);
this.id = id;
}
public Centroid(int id, Point p) {
super(p.x, p.y);
this.id = id;
}
@Override
public String toString() {
return id + " " + super.toString();
}
}
簇中心(质心)的定义类,是基于Point的。其由一个质心Id和质心的位置坐标组成。
如果主程序执行时没有指定输入的CSV文件路径,那么就读取默认数据。默认数据的定义如下:
package org.apache.flink.examples.java.clustering.util;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.examples.java.clustering.KMeans.Centroid;
import org.apache.flink.examples.java.clustering.KMeans.Point;
import java.util.LinkedList;
import java.util.List;
/**
* 提供用于K-Means示例程序的默认数据集。如果没有为程序提供参数,则使用默认数据集。
*
*/
public class KMeansData {
/**
* 簇中心(质心)数据
*/
public static final Object[][] CENTROIDS = new Object[][] {
new Object[] {1, -31.85, -44.77},
new Object[]{2, 35.16, 17.46},
new Object[]{3, -5.16, 21.93},
new Object[]{4, -24.06, 6.81}
};
/**
* 输入的点数据
*/
public static final Object[][] POINTS = new Object[][] {
new Object[] {-14.22, -48.01},
new Object[] {-22.78, 37.10},
new Object[] {56.18, -42.99},
new Object[] {35.04, 50.29},
new Object[] {-9.53, -46.26},
new Object[] {-34.35, 48.25},
new Object[] {55.82, -57.49},
new Object[] {21.03, 54.64},
new Object[] {-13.63, -42.26},
new Object[] {-36.57, 32.63},
new Object[] {50.65, -52.40},
new Object[] {24.48, 34.04},
new Object[] {-2.69, -36.02},
new Object[] {-38.80, 36.58},
new Object[] {24.00, -53.74},
new Object[] {32.41, 24.96},
new Object[] {-4.32, -56.92},
new Object[] {-22.68, 29.42},
new Object[] {59.02, -39.56},
new Object[] {24.47, 45.07},
new Object[] {5.23, -41.20},
new Object[] {-23.00, 38.15},
new Object[] {44.55, -51.50},
new Object[] {14.62, 59.06},
new Object[] {7.41, -56.05},
new Object[] {-26.63, 28.97},
new Object[] {47.37, -44.72},
new Object[] {29.07, 51.06},
new Object[] {0.59, -31.89},
new Object[] {-39.09, 20.78},
new Object[] {42.97, -48.98},
new Object[] {34.36, 49.08},
new Object[] {-21.91, -49.01},
new Object[] {-46.68, 46.04},
new Object[] {48.52, -43.67},
new Object[] {30.05, 49.25},
new Object[] {4.03, -43.56},
new Object[] {-37.85, 41.72},
new Object[] {38.24, -48.32},
new Object[] {20.83, 57.85}
};
/**
* 得到默认的质心数据
* @param env
* @return
*/
public static DataSet<Centroid> getDefaultCentroidDataSet(ExecutionEnvironment env) {
List<Centroid> centroidList = new LinkedList<Centroid>();
for (Object[] centroid : CENTROIDS) {
centroidList.add(
new Centroid((Integer) centroid[0], (Double) centroid[1], (Double) centroid[2]));
}
return env.fromCollection(centroidList);
}
/**
* 得到默认的点数据
* @param env
* @return
*/
public static DataSet<Point> getDefaultPointDataSet(ExecutionEnvironment env) {
List<Point> pointList = new LinkedList<Point>();
for (Object[] point : POINTS) {
pointList.add(new Point((Double) point[0], (Double) point[1]));
}
return env.fromCollection(pointList);
}
}
主程序入口的代码如下,下面将逐步地分析代码的逻辑。
public static void main(String[] args) throws Exception {
// 1.解析命令行参数
final ParameterTool params = ParameterTool.fromArgs(args);
// 2. 构建执行环境
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
// 3. 使参数在Web界面中可用
env.getConfig().setGlobalJobParameters(params);
// 4. 得到输入数据:从提供的路径读取点和质心,或返回默认数据
DataSet<Point> points = getPointDataSet(params, env);
DataSet<Centroid> centroids = getCentroidDataSet(params, env);
// 5. 为K-Means算法设置批量迭代次数
IterativeDataSet<Centroid> loop = centroids.iterate(params.getInt("iterations", 10));
// 6. K-Means算法计算过程
DataSet<Centroid> newCentroids = points
// 6.1. 计算每个点距离最近的质心
.map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
// 6.2. 每个簇内的所有点坐标求和
.map(new CountAppender())
.groupBy(0).reduce(new CentroidAccumulator())
// 6.3. 根据点计数和坐标和计算新的质心
.map(new CentroidAverager());
// 7. 将新的质心数据反馈到下一个迭代中
DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids);
// 8. 将点归宿给最终的簇
DataSet<Tuple2<Integer, Point>> clusteredPoints = points
.map(new SelectNearestCenter()).withBroadcastSet(finalCentroids, "centroids");
// 9. 指定输出结果路径和执行
if (params.has("output")) {
clusteredPoints.writeAsCsv(params.get("output"), "\n", " ");
env.execute("KMeans Example");
} else {
System.out.println("Printing result to stdout. Use --output to specify output path.");
clusteredPoints.print();
}
}
上面是执行主函数逻辑的全部代码。代码注释中,我将逻辑代码注释成了9步。所以下面将主要解释下重要步骤的实现细节。
final ParameterTool params = ParameterTool.fromArgs(args);
任务在提交执行时是可以指定参数的,主要可传参数包括:
当然如果执行时,某些参数不传,那么系统会读取默认的。
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
这里没有什么好说的,就是Flink任务必须要做的事情,初始化执行程序的上下文环境。
env.getConfig().setGlobalJobParameters(params);
这一步的目的也很简单,使参数在WebUI界面中可用。所以也不多说明。
DataSet<Point> points = getPointDataSet(params, env);
DataSet<Centroid> centroids = getCentroidDataSet(params, env);
这一步是得到输入的数据。输入的数据包含了两部分:点数据集和聚集中心(质心)数据集。
得到点数据集的函数
/**
* 得到输入的点数据集
* @param params
* @param env
* @return
*/
private static DataSet<Point> getPointDataSet(ParameterTool params, ExecutionEnvironment env) {
DataSet<Point> points;
// 如果有“points”这个输入参数,则从指定CSV路径中读入点数据源
if (params.has("points")) {
points = env.readCsvFile(params.get("points"))
.fieldDelimiter(" ")
.pojoType(Point.class, "x", "y");
// 否则,读取默认的数据源
} else {
System.out.println("Executing K-Means example with default point data set.");
System.out.println("Use --points to specify file input.");
points = KMeansData.getDefaultPointDataSet(env);
}
return points;
}
得到质心数据集的函数
/**
* 得到质心数据集
* @param params
* @param env
* @return
*/
private static DataSet<Centroid> getCentroidDataSet(ParameterTool params, ExecutionEnvironment env) {
DataSet<Centroid> centroids;
// 如果指定了质心数据集的读入csv文件路径,那么就读取。
if (params.has("centroids")) {
centroids = env.readCsvFile(params.get("centroids"))
.fieldDelimiter(" ")
.pojoType(Centroid.class, "id", "x", "y");
// 否则,那么就读取默认数据
} else {
System.out.println("Executing K-Means example with default centroid data set.");
System.out.println("Use --centroids to specify file input.");
centroids = KMeansData.getDefaultCentroidDataSet(env);
}
return centroids;
}
IterativeDataSet<Centroid> loop = centroids.iterate(params.getInt("iterations", 10));
这一步是为K-Means算法设置批量迭代次数,默认是迭代10次。
DataSet<Centroid> newCentroids = points
// 6.1. 计算每个点距离最近的质心
.map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
// 6.2. 每个质心的点坐标计数和求和
.map(new CountAppender())
.groupBy(0).reduce(new CentroidAccumulator())
// 6.3. 根据点计数和坐标和计算新的质心
.map(new CentroidAverager());
这里是真正迭代运算的计算逻辑。其细节过程是是分步的。因为这里逻辑是算法的核心了,我们有必要细看下。
第一步:计算每个点距离最近的质心
.map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
这里重点看下SelectNearestCenter类的执行逻辑。
/** 确定数据点最近的群集中心. */
@ForwardedFields("*->1")
public static final class SelectNearestCenter extends RichMapFunction<Point, Tuple2<Integer, Point>> {
private Collection<Centroid> centroids;
/** 将广播变量中的质心数据集读取到集合中*/
@Override
public void open(Configuration parameters) throws Exception {
this.centroids = getRuntimeContext().getBroadcastVariable("centroids");
}
@Override
public Tuple2<Integer, Point> map(Point p) throws Exception {
double minDistance = Double.MAX_VALUE;
int closestCentroidId = -1;
// 遍历所有的簇中心
for (Centroid centroid : centroids) {
// 计算点和簇中心的欧式距离
double distance = p.euclideanDistance(centroid);
// 找到距离点最近的簇中心
if (distance < minDistance) {
minDistance = distance;
closestCentroidId = centroid.id;
}
}
// 输出一条心的记录,由簇中心id和Point组成.
return new Tuple2<>(closestCentroidId, p);
}
}
第二步: 每个簇内的所有点坐标求和
.map(new CountAppender()).groupBy(0)
.reduce(new CentroidAccumulator())
这里重点看下CountAppender类的执行逻辑和CentroidAccumulator类的执行逻辑:
/** 对 Tuple2进行计数 */
@ForwardedFields("f0;f1")
public static final class CountAppender implements MapFunction<Tuple2<Integer, Point>, Tuple3<Integer, Point, Long>> {
@Override
public Tuple3<Integer, Point, Long> map(Tuple2<Integer, Point> t) {
// 对簇内点进行计数
return new Tuple3<>(t.f0, t.f1, 1L);
}
}
/** 对簇内点计数以及对簇内点的坐标进行累加 */
@ForwardedFields("0")
public static final class CentroidAccumulator implements ReduceFunction<Tuple3<Integer, Point, Long>> {
@Override
public Tuple3<Integer, Point, Long> reduce(Tuple3<Integer, Point, Long> val1, Tuple3<Integer, Point, Long> val2) {
// 这一步逻辑很关键,对簇内点坐标累计,然后对簇内元素个数计数。
return new Tuple3<>(val1.f0, val1.f1.add(val2.f1), val1.f2 + val2.f2);
}
}
这一步实现了对每个簇内的元素(点)个数进行了计数,然后对簇内的这些点的坐标进行了累加。
第三步: 根据点计数和坐标和计算新的质心
.map(new CentroidAverager());
这里看下CentroidAverager类的逻辑。
/** 从簇内点的个数和这些点的坐标和计算出新的簇中心*/
@ForwardedFields("0->id")
public static final class CentroidAverager implements MapFunction<Tuple3<Integer, Point, Long>, Centroid> {
@Override
public Centroid map(Tuple3<Integer, Point, Long> value) {
// 坐标和/簇内点个数作为新的簇中心
return new Centroid(value.f0, value.f1.div(value.f2));
}
}
这一步是根据上一步计算的簇内元素个数,以及这些元素的坐标和来求得新的簇中心坐标。计算方式是(坐标和/簇内元素个数)。
DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids);
将上一次迭代计算得到的新的簇中心数据newCentroids反馈给loop,然后进行下一次迭代。
其实3.4.5到3.4.7是可以一起看的,这三步定义了批量迭代计算的逻辑。也是迭代计算(iterative computation)的定义模板。
DataSet<Tuple2<Integer, Point>> clusteredPoints = points
.map(new SelectNearestCenter()).withBroadcastSet(finalCentroids, "centroids");
将过上述三步的迭代计算之后,就可以确定下来最终的稳定的簇。那么这一步就开始把每个点归宿给最终的簇了。逻辑还是一样,通过欧式定理来归属。SelectNearestCenter类的实现逻辑在前文中讲过,所以这里不做赘述了。
最后一步就是制定输出结果方式和执行。
if (params.has("output")) {
clusteredPoints.writeAsCsv(params.get("output"), "\n", " ");
env.execute("KMeans Example");
} else {
System.out.println("Printing result to stdout. Use --output to specify output path.");
clusteredPoints.print();
}
运行之后的结果是,把每个顶点都归宿到了各个簇去了。
结果如下:
(1,-14.22 -48.01)
(4,-22.78 37.1)
(2,56.18 -42.99)
(3,35.04 50.29)
(1,-9.53 -46.26)
(4,-34.35 48.25)
(2,55.82 -57.49)
(3,21.03 54.64)
(1,-13.63 -42.26)
(4,-36.57 32.63)
(2,50.65 -52.4)
(3,24.48 34.04)
(1,-2.69 -36.02)
(4,-38.8 36.58)
(2,24.0 -53.74)
(3,32.41 24.96)
(1,-4.32 -56.92)
(4,-22.68 29.42)
(2,59.02 -39.56)
(3,24.47 45.07)
(1,5.23 -41.2)
(4,-23.0 38.15)
(2,44.55 -51.5)
(3,14.62 59.06)
(1,7.41 -56.05)
(4,-26.63 28.97)
(2,47.37 -44.72)
(3,29.07 51.06)
(1,0.59 -31.89)
(4,-39.09 20.78)
(2,42.97 -48.98)
(3,34.36 49.08)
(1,-21.91 -49.01)
(4,-46.68 46.04)
(2,48.52 -43.67)
(3,30.05 49.25)
(1,4.03 -43.56)
(4,-37.85 41.72)
(2,38.24 -48.32)
(3,20.83 57.85)
本案例的难点在于迭代计算的应用。机器学习算法的本质就是一个迭代计算,然后在迭代中减少损失函数的不断优化过程。掌握Flink的迭代计算,将为我们设计出更多复杂有效的机器学习模型打下基础。
后续文章中会继续推出,怎么在Flink上实现更多复杂有趣的机器学习模型。