package org.apache.mahout.clustering.kmeans;
public class KMeansDriver extends AbstractJob
KMeansDriver is the entry point class for k-means.
Its run method calls buildClusters and then, optionally, clusterData:
Path clustersOut = buildClusters(conf, input, clustersIn, output, measure, maxIterations, delta, runSequential);
if (runClustering) {
log.info("Clustering data");
clusterData(conf,
input,
clustersOut,
new Path(output, AbstractCluster.CLUSTERED_POINTS_DIR),
measure,
delta,
runSequential);
}
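For orientation, a typical invocation of this driver from the Mahout command line looks roughly like this (flag names are from the Mahout 0.x era and may vary between versions):

bin/mahout kmeans \
  -i vectors \
  -c initial-clusters \
  -o kmeans-output \
  -dm org.apache.mahout.common.distance.EuclideanDistanceMeasure \
  -cd 0.5 \
  -x 10 \
  -cl

Here -cd maps to the convergence delta, -x to maxIterations, and -cl to the runClustering flag seen above.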
buildClusters provides two implementations, selected by the runSequential flag:
if (runSequential) {
return buildClustersSeq(conf, input, clustersIn, output, measure, maxIterations, delta);
} else {
return buildClustersMR(conf, input, clustersIn, output, measure, maxIterations, delta);
}
buildClustersMR implements the iterative center-update loop; each iteration writes its clusters to a fresh directory, which then becomes the input for the next iteration:
boolean converged = false;
int iteration = 1;
while (!converged && iteration <= maxIterations) {
log.info("K-Means Iteration {}", iteration);
// point the output to a new directory per iteration
Path clustersOut = new Path(output, AbstractCluster.CLUSTERS_DIR + iteration);
converged = runIteration(conf, input, clustersIn, clustersOut, measure.getClass().getName(), delta);
// now point the input to the old output directory
clustersIn = clustersOut;
iteration++;
}
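How converged is determined after each iteration is not shown above; a plausible sketch (an assumption, not the verbatim Mahout helper) is that the driver reads the clusters the iteration just wrote and reports convergence only if every one of them converged:

// sketch: scan the iteration's output clusters; all must have converged
private static boolean allConverged(Iterable<Cluster> clusters) {
  for (Cluster cluster : clusters) {
    if (!cluster.isConverged()) {
      return false;  // one moving center keeps the loop going
    }
  }
  return true;
}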
runIteration sets up the core MapReduce job:
job.setMapOutputKeyClass(Text.class);
job.setMapOutputValueClass(ClusterObservations.class);
job.setOutputKeyClass(Text.class);
job.setOutputValueClass(Cluster.class);
Both input and output are sequence files:
job.setInputFormatClass(SequenceFileInputFormat.class);
job.setOutputFormatClass(SequenceFileOutputFormat.class);
job.setMapperClass(KMeansMapper.class);
job.setCombinerClass(KMeansCombiner.class);
job.setReducerClass(KMeansReducer.class);
package org.apache.mahout.clustering.kmeans;
The KMeansMapper class:
public class KMeansMapper extends Mapper<WritableComparable<?>, VectorWritable, Text, ClusterObservations>
// the clusterer that does the actual assignment work
private KMeansClusterer clusterer;
// holds the cluster centers
private final Collection<Cluster> clusters = new ArrayList<Cluster>();
The setup method loads the distance measure class, initializes the KMeansClusterer, and reads in the cluster centers:
ClassLoader ccl = Thread.currentThread().getContextClassLoader();
DistanceMeasure measure = ccl.loadClass(conf.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY))
.asSubclass(DistanceMeasure.class).newInstance();
measure.configure(conf);
this.clusterer = new KMeansClusterer(measure);
String clusterPath = conf.get(KMeansConfigKeys.CLUSTER_PATH_KEY);
if (clusterPath != null && clusterPath.length() > 0) {
KMeansUtil.configureWithClusterInfo(conf, new Path(clusterPath), clusters);
if (clusters.isEmpty()) {
throw new IllegalStateException("No clusters found. Check your -c path.");
}
}
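KMeansUtil.configureWithClusterInfo essentially deserializes Cluster instances from the sequence files under that path. A minimal sketch of the idea with the plain Hadoop API (an illustration, not the actual KMeansUtil code; the key type of the cluster files is assumed to be Text):

FileSystem fs = FileSystem.get(clusterPath.toUri(), conf);
for (FileStatus part : fs.listStatus(clusterPath)) {
  SequenceFile.Reader reader = new SequenceFile.Reader(fs, part.getPath(), conf);
  Text key = new Text();
  Cluster value = new Cluster();
  while (reader.next(key, value)) {
    clusters.add(value);    // keep the deserialized center
    value = new Cluster();  // fresh instance for the next record
  }
  reader.close();
}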
The map method simply delegates each incoming point to the clusterer:
this.clusterer.emitPointToNearestCluster(point.get(), this.clusters, context);
KMeansClusterer is the core class implementing the algorithm.
emitPointToNearestCluster iterates over the cluster centers and picks the nearest one by distance.
It emits key: the identifier of the nearest center; value: a ClusterObservations wrapping the point.
ClusterObservations carries the sufficient statistics s0 (point count), s1 (sum of the vectors), and s2 (sum of the element-wise squared vectors), which make the later center computation cheap:
Cluster nearestCluster = null;
double nearestDistance = Double.MAX_VALUE;
for (Cluster cluster : clusters) {
Vector clusterCenter = cluster.getCenter();
double distance = this.measure.distance(clusterCenter.getLengthSquared(), clusterCenter, point);
if (distance < nearestDistance || nearestCluster == null) {
nearestCluster = cluster;
nearestDistance = distance;
}
}
context.write(new Text(nearestCluster.getIdentifier()), new ClusterObservations(1, point, point.times(point)));
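To make the role of s0/s1/s2 concrete, here is a minimal self-contained sketch (not Mahout source) of how these sufficient statistics accumulate and later yield the new center:

import org.apache.mahout.math.Vector;

class ObservationsSketch {
  private double s0;  // number of points observed
  private Vector s1;  // running sum of the points
  private Vector s2;  // running sum of the element-wise squared points

  void observe(Vector point) {
    s0++;
    s1 = (s1 == null) ? point.clone() : s1.plus(point);
    Vector squared = point.times(point);
    s2 = (s2 == null) ? squared : s2.plus(squared);
  }

  Vector mean() {
    return s1.divide(s0);  // the recomputed cluster center
  }
}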
KMeansCombiner pre-aggregates the map output:
public class KMeansCombiner extends Reducer<Text, ClusterObservations, Text, ClusterObservations>
It merges the counts and sums of all vectors assigned to the same center; this is safe because s0, s1, and s2 are associative sums, so partial aggregates combine into the same totals:
@Override
protected void reduce(Text key, Iterable<ClusterObservations> values, Context context)
throws IOException, InterruptedException {
Cluster cluster = new Cluster();
for (ClusterObservations value : values) {
cluster.observe(value);
}
context.write(key, cluster.getObservations());
}
The KMeansReducer class:
public class KMeansReducer extends Reducer<Text, ClusterObservations, Text, Cluster>
It merges all observations for each center, checks convergence, and recomputes the center as the mean of the assigned vectors, i.e. the accumulated sum s1 divided by the count s0.
Output key: the cluster identifier; value: the updated cluster.
@Override
protected void reduce(Text key, Iterable<ClusterObservations> values, Context context)
throws IOException, InterruptedException {
Cluster cluster = clusterMap.get(key.toString());
for (ClusterObservations delta : values) {
cluster.observe(delta);
}
// force convergence calculation
boolean converged = clusterer.computeConvergence(cluster, convergenceDelta);
if (converged) {
context.getCounter("Clustering", "Converged Clusters").increment(1);
}
cluster.computeParameters();
context.write(new Text(cluster.getIdentifier()), cluster);
}
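What computeConvergence boils down to, as a hedged sketch (assumed behavior, not the verbatim Cluster code): the cluster counts as converged when its recomputed centroid has moved by at most convergenceDelta from the previous center.

boolean convergedSketch(DistanceMeasure measure, double convergenceDelta,
                        Vector oldCenter, Vector s1, double s0) {
  Vector centroid = s1.divide(s0);  // candidate new center, mean of the points
  return measure.distance(centroid, oldCenter) <= convergenceDelta;
}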
clusterData again offers two implementations, sequential (single-machine) and distributed MapReduce:
if (runSequential) {
clusterDataSeq(conf, input, clustersIn, output, measure);
} else {
clusterDataMR(conf, input, clustersIn, output, measure, convergenceDelta);
}
clusterDataMR also uses sequence files for input and output; the output key is an IntWritable (the cluster id) and the value a WeightedVectorWritable (the point):
job.setInputFormatClass(SequenceFileInputFormat.class);
job.setOutputFormatClass(SequenceFileOutputFormat.class);
job.setOutputKeyClass(IntWritable.class);
job.setOutputValueClass(WeightedVectorWritable.class);
The job is map-only, with no reduce phase: each point's assignment depends only on the fixed final centers, so nothing needs aggregating:
job.setMapperClass(KMeansClusterMapper.class);
job.setNumReduceTasks(0);
The KMeansClusterMapper class:
public class KMeansClusterMapper extends Mapper<WritableComparable<?>,VectorWritable,IntWritable,WeightedVectorWritable>
private final Collection<Cluster> clusters = new ArrayList<Cluster>();
private KMeansClusterer clusterer;
It labels each point with its final cluster and writes it out:
@Override
protected void map(WritableComparable<?> key, VectorWritable point, Context context)
throws IOException, InterruptedException {
clusterer.outputPointWithClusterInfo(point.get(), clusters, context);
}
The outputPointWithClusterInfo method iterates over all centers, finds the nearest one, and emits
key: the cluster id; value: the point wrapped in a WeightedVectorWritable:
AbstractCluster nearestCluster = null;
double nearestDistance = Double.MAX_VALUE;
for (AbstractCluster cluster : clusters) {
Vector clusterCenter = cluster.getCenter();
double distance = measure.distance(clusterCenter.getLengthSquared(), clusterCenter, vector);
if (distance < nearestDistance || nearestCluster == null) {
nearestCluster = cluster;
nearestDistance = distance;
}
}
context.write(new IntWritable(nearestCluster.getId()), new WeightedVectorWritable(1, vector));
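Once the job finishes, the labeled points under clusteredPoints can be inspected with a plain SequenceFile reader; the part-file path below is an assumption for illustration:

Configuration conf = new Configuration();
Path path = new Path("kmeans-output/clusteredPoints/part-m-00000");
FileSystem fs = FileSystem.get(conf);
SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
IntWritable clusterId = new IntWritable();
WeightedVectorWritable observation = new WeightedVectorWritable();
while (reader.next(clusterId, observation)) {
  System.out.println(clusterId.get() + "\t" + observation.getVector());
}
reader.close();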