Reading the code: KMeansDriver

package org.apache.mahout.clustering.kmeans;
public class KMeansDriver extends AbstractJob
KMeansDriver is the entry class for k-means.


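run() calls buildClusters to compute the cluster centers and then, if runClustering is set, clusterData to assign the input points to them: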
    Path clustersOut = buildClusters(conf, input, clustersIn, output, measure, maxIterations, delta, runSequential);
    if (runClustering) {
      log.info("Clustering data");
      clusterData(conf,
          input,
          clustersOut,
          new Path(output, AbstractCluster.CLUSTERED_POINTS_DIR),
          measure,
          delta,
          runSequential);
    }


buildClusters offers two implementations, a sequential one and a MapReduce one:
    if (runSequential) {
      return buildClustersSeq(conf, input, clustersIn, output, measure, maxIterations, delta);
    } else {
      return buildClustersMR(conf, input, clustersIn, output, measure, maxIterations, delta);
    }


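buildClustersMR implements the iterative update of the cluster centers. Each iteration writes its clusters to a new directory, which becomes the input of the next iteration; the loop stops once the clusters converge or maxIterations is reached: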
    boolean converged = false;
    int iteration = 1;
    while (!converged && iteration <= maxIterations) {
      log.info("K-Means Iteration {}", iteration);
      // point the output to a new directory per iteration
      Path clustersOut = new Path(output, AbstractCluster.CLUSTERS_DIR + iteration);
      converged = runIteration(conf, input, clustersIn, clustersOut, measure.getClass().getName(), delta);
      // now point the input to the old output directory
      clustersIn = clustersOut;
      iteration++;
    }
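To make the control flow concrete, here is a minimal in-memory sketch of the same loop in plain Java. It is not Mahout code: KMeansLoopSketch and its methods are hypothetical, points are plain double[] arrays, Euclidean distance is hard-coded, and an iteration is assumed to have converged when no center moved more than delta. In the real driver each runIteration call is a full MapReduce job that reads clustersIn and writes clustersOut.

```java
// In-memory sketch of the buildClustersMR loop (not Mahout code).
// Assumption: an iteration has converged when no center moved more than delta
// (Euclidean distance). In Mahout each runIteration call is a whole MapReduce job.
class KMeansLoopSketch {

  static double distance(double[] a, double[] b) {
    double sum = 0;
    for (int i = 0; i < a.length; i++) {
      double d = a[i] - b[i];
      sum += d * d;
    }
    return Math.sqrt(sum);
  }

  static int nearest(double[][] centers, double[] p) {
    int best = 0;
    double bestDist = Double.MAX_VALUE;
    for (int c = 0; c < centers.length; c++) {
      double d = distance(centers[c], p);
      if (d < bestDist) {
        best = c;
        bestDist = d;
      }
    }
    return best;
  }

  // One iteration: assign points, then replace each center with the mean of its points.
  // Updates centers in place and returns true when every center moved at most delta.
  static boolean runIteration(double[][] points, double[][] centers, double delta) {
    int k = centers.length;
    int dim = centers[0].length;
    double[][] sums = new double[k][dim];
    int[] counts = new int[k];
    for (double[] p : points) {
      int c = nearest(centers, p);
      counts[c]++;
      for (int i = 0; i < dim; i++) {
        sums[c][i] += p[i];
      }
    }
    boolean converged = true;
    for (int c = 0; c < k; c++) {
      if (counts[c] == 0) {
        continue; // sketch choice: an empty cluster keeps its old center
      }
      double[] newCenter = new double[dim];
      for (int i = 0; i < dim; i++) {
        newCenter[i] = sums[c][i] / counts[c];
      }
      if (distance(centers[c], newCenter) > delta) {
        converged = false;
      }
      centers[c] = newCenter;
    }
    return converged;
  }

  // Same control flow as the while loop above; returns the number of iterations run.
  static int buildClusters(double[][] points, double[][] centers, int maxIterations, double delta) {
    boolean converged = false;
    int iteration = 1;
    while (!converged && iteration <= maxIterations) {
      converged = runIteration(points, centers, delta);
      iteration++;
    }
    return iteration - 1;
  }
}
```

For example, with points (0,0), (0,1), (10,10), (10,11) and initial centers (0,0) and (10,10), the first iteration moves the centers to (0,0.5) and (10,10.5); the second iteration moves nothing, so buildClusters returns 2.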


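runIteration is where the MapReduce core begins. The map output is Text keys with ClusterObservations values; the job output is Text keys with Cluster values: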
    job.setMapOutputKeyClass(Text.class);
    job.setMapOutputValueClass(ClusterObservations.class);
    job.setOutputKeyClass(Text.class);
    job.setOutputValueClass(Cluster.class);


Both input and output are SequenceFiles, and the job uses a mapper, a combiner and a reducer:
    job.setInputFormatClass(SequenceFileInputFormat.class);
    job.setOutputFormatClass(SequenceFileOutputFormat.class);
    job.setMapperClass(KMeansMapper.class);
    job.setCombinerClass(KMeansCombiner.class);
    job.setReducerClass(KMeansReducer.class);



package org.apache.mahout.clustering.kmeans;
The KMeansMapper class
public class KMeansMapper extends Mapper<WritableComparable<?>, VectorWritable, Text, ClusterObservations>
// the clusterer that does the actual work
private KMeansClusterer clusterer;
// holds the cluster centers
private final Collection<Cluster> clusters = new ArrayList<Cluster>();

setup() loads the distance measure class, creates the KMeansClusterer, and loads the current cluster centers:
      ClassLoader ccl = Thread.currentThread().getContextClassLoader();
      DistanceMeasure measure = ccl.loadClass(conf.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY))
          .asSubclass(DistanceMeasure.class).newInstance();
      measure.configure(conf);

      this.clusterer = new KMeansClusterer(measure);

      String clusterPath = conf.get(KMeansConfigKeys.CLUSTER_PATH_KEY);
      if (clusterPath != null && clusterPath.length() > 0) {
        KMeansUtil.configureWithClusterInfo(conf, new Path(clusterPath), clusters);
        if (clusters.isEmpty()) {
          throw new IllegalStateException("No clusters found. Check your -c path.");
        }
      }
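The reflective loading in setup() can be tried in isolation. The sketch below uses hypothetical SketchDistanceMeasure and SquaredEuclideanSketch types instead of Mahout's DistanceMeasure, and it skips measure.configure(conf). It also uses Class.forName plus getDeclaredConstructor().newInstance() rather than the context class loader and Class.newInstance() seen above, since Class.newInstance() is deprecated as of Java 9.

```java
// Sketch of the reflection pattern used in setup(): the measure's class name comes
// from the job configuration and is instantiated at runtime.
// SketchDistanceMeasure and SquaredEuclideanSketch are hypothetical, not Mahout classes.
interface SketchDistanceMeasure {
  double distance(double[] a, double[] b);
}

class SquaredEuclideanSketch implements SketchDistanceMeasure {
  // Reflection needs an accessible no-argument constructor.
  public SquaredEuclideanSketch() {
  }

  @Override
  public double distance(double[] a, double[] b) {
    double sum = 0;
    for (int i = 0; i < a.length; i++) {
      double d = a[i] - b[i];
      sum += d * d;
    }
    return sum;
  }
}

class MeasureLoader {
  // asSubclass throws ClassCastException if the class is not a SketchDistanceMeasure;
  // other reflection failures are wrapped in IllegalStateException.
  static SketchDistanceMeasure load(String className) {
    try {
      return Class.forName(className)
          .asSubclass(SketchDistanceMeasure.class)
          .getDeclaredConstructor()
          .newInstance();
    } catch (ReflectiveOperationException e) {
      throw new IllegalStateException("Cannot instantiate " + className, e);
    }
  }
}
```

Whatever class the configuration names must implement the interface and have a no-argument constructor; otherwise asSubclass rejects it or the instantiation fails.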


map() hands each point to the clusterer:
this.clusterer.emitPointToNearestCluster(point.get(), this.clusters, context);


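KMeansClusterer is the class that implements the core of the algorithm.
emitPointToNearestCluster iterates over the cluster centers and uses the distance measure to find the center nearest to the point.
Output key: the identifier of the nearest cluster; value: a ClusterObservations wrapping the point.
ClusterObservations holds s0 (the vector count), s1 (the sum of the vectors) and s2 (the sum of the element-wise squared vectors),
which is what the later steps need to recompute the centers: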
    Cluster nearestCluster = null;
    double nearestDistance = Double.MAX_VALUE;
    for (Cluster cluster : clusters) {
      Vector clusterCenter = cluster.getCenter();
      double distance = this.measure.distance(clusterCenter.getLengthSquared(), clusterCenter, point);
      if (distance < nearestDistance || nearestCluster == null) {
        nearestCluster = cluster;
        nearestDistance = distance;
      }
    }
    context.write(new Text(nearestCluster.getIdentifier()), new ClusterObservations(1, point, point.times(point)));
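The assignment step and the per-point observation can be sketched on plain arrays. PointObservations and NearestClusterSketch are hypothetical names, and the distance is fixed to squared Euclidean. The sketch also suggests why measure.distance takes clusterCenter.getLengthSquared() as its first argument: for squared Euclidean distance, ||c - p||^2 = ||c||^2 - 2 c·p + ||p||^2, so a precomputed ||c||^2 saves work per point. That this is Mahout's motivation is an assumption.

```java
// Sketch of emitPointToNearestCluster for one point, on plain double[] vectors.
// PointObservations is a hypothetical stand-in for ClusterObservations.
class PointObservations {
  final double s0;    // number of vectors
  final double[] s1;  // element-wise sum of the vectors
  final double[] s2;  // element-wise sum of the squared vectors

  PointObservations(double s0, double[] s1, double[] s2) {
    this.s0 = s0;
    this.s1 = s1;
    this.s2 = s2;
  }

  // A single point x contributes (1, x, x*x), like new ClusterObservations(1, point, point.times(point)).
  static PointObservations ofPoint(double[] x) {
    double[] squares = new double[x.length];
    for (int i = 0; i < x.length; i++) {
      squares[i] = x[i] * x[i];
    }
    return new PointObservations(1, x.clone(), squares);
  }
}

class NearestClusterSketch {
  // Squared Euclidean distance using a precomputed ||c||^2:
  // ||c - p||^2 = ||c||^2 - 2 c.p + ||p||^2
  static double distance(double centerLengthSquared, double[] center, double[] p) {
    double dot = 0;
    double pLengthSquared = 0;
    for (int i = 0; i < p.length; i++) {
      dot += center[i] * p[i];
      pLengthSquared += p[i] * p[i];
    }
    return centerLengthSquared - 2 * dot + pLengthSquared;
  }

  // Returns the index of the nearest center; the index plays the role of the cluster identifier.
  static int nearest(double[][] centers, double[] p) {
    int best = -1;
    double bestDistance = Double.MAX_VALUE;
    for (int c = 0; c < centers.length; c++) {
      double lengthSquared = 0; // recomputed here for brevity; ideally cached per center
      for (double v : centers[c]) {
        lengthSquared += v * v;
      }
      double d = distance(lengthSquared, centers[c], p);
      if (d < bestDistance || best == -1) {
        best = c;
        bestDistance = d;
      }
    }
    return best;
  }
}
```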



The KMeansCombiner class pre-aggregates the map output.
public class KMeansCombiner extends Reducer<Text, ClusterObservations, Text, ClusterObservations>
For each cluster it accumulates the count and the sums of all vectors assigned to it:
  @Override
  protected void reduce(Text key, Iterable<ClusterObservations> values, Context context)
    throws IOException, InterruptedException {
    Cluster cluster = new Cluster();
    for (ClusterObservations value : values) {
      cluster.observe(value);
    }
    context.write(key, cluster.getObservations());
  }
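The combiner is only safe because observing is plain summation. The sketch below assumes Cluster.observe adds s0, s1 and s2 component-wise; ObservationSum is a hypothetical class, not Mahout's Cluster.

```java
// Sketch of what the combiner's cluster.observe(value) does with the running sums.
// Assumption: observing adds s0, s1 and s2 component-wise; ObservationSum is hypothetical.
class ObservationSum {
  double s0;
  double[] s1; // null until the first observation fixes the dimension
  double[] s2;

  void observe(double n, double[] sum, double[] sumOfSquares) {
    if (s1 == null) {
      s1 = new double[sum.length];
      s2 = new double[sumOfSquares.length];
    }
    s0 += n;
    for (int i = 0; i < sum.length; i++) {
      s1[i] += sum[i];
      s2[i] += sumOfSquares[i];
    }
  }

  // Observing a single point x is observe(1, x, x*x).
  void observePoint(double[] x) {
    double[] squares = new double[x.length];
    for (int i = 0; i < x.length; i++) {
      squares[i] = x[i] * x[i];
    }
    observe(1, x, squares);
  }

  // Merging another non-empty partial sum, as the reducer merges combiner outputs.
  void observe(ObservationSum other) {
    observe(other.s0, other.s1, other.s2);
  }
}
```

Because addition is associative and commutative, merging per-mapper partial sums yields the same s0, s1 and s2 as summing every point in one place (exactly so for small integer inputs; with general floating-point data the order of additions can change the last bits).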


The KMeansReducer class
public class KMeansReducer extends Reducer<Text, ClusterObservations, Text, Cluster>
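It aggregates the observations of each cluster, checks convergence, and recomputes the cluster center
as the mean vector, i.e. the sum of all vectors divided by their count.
Output key: the cluster identifier; value: the new cluster.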
  @Override
  protected void reduce(Text key, Iterable<ClusterObservations> values, Context context)
    throws IOException, InterruptedException {
    Cluster cluster = clusterMap.get(key.toString());
    for (ClusterObservations delta : values) {
      cluster.observe(delta);
    }
    // force convergence calculation
    boolean converged = clusterer.computeConvergence(cluster, convergenceDelta);
    if (converged) {
      context.getCounter("Clustering", "Converged Clusters").increment(1);
    }
    cluster.computeParameters();
    context.write(new Text(cluster.getIdentifier()), cluster);
  }
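The arithmetic the reducer performs on the aggregated sums can be written out directly. This is a sketch, not Mahout's Cluster code: ReduceStepSketch is hypothetical, convergence is assumed to mean the old center lies within convergenceDelta (Euclidean) of the new mean, and stdDev only illustrates what s2 makes possible. Note the order in the reducer above: computeConvergence runs before computeParameters, so the comparison is presumably between the previous center and the new mean.

```java
// Sketch of the reducer's final step on the aggregated sums (not Mahout's code).
// Assumptions: the new center is s1 / s0, and a cluster counts as converged when its
// old center is within delta (Euclidean) of the new one.
class ReduceStepSketch {

  static double[] centroid(double s0, double[] s1) {
    double[] c = new double[s1.length];
    for (int i = 0; i < s1.length; i++) {
      c[i] = s1[i] / s0;
    }
    return c;
  }

  static boolean converged(double[] oldCenter, double[] newCenter, double delta) {
    double sum = 0;
    for (int i = 0; i < oldCenter.length; i++) {
      double d = oldCenter[i] - newCenter[i];
      sum += d * d;
    }
    return Math.sqrt(sum) <= delta;
  }

  // Per-dimension standard deviation from the sums: sqrt(s2/s0 - (s1/s0)^2).
  // Math.max guards against tiny negative values caused by rounding.
  static double[] stdDev(double s0, double[] s1, double[] s2) {
    double[] sd = new double[s1.length];
    for (int i = 0; i < s1.length; i++) {
      double mean = s1[i] / s0;
      sd[i] = Math.sqrt(Math.max(0.0, s2[i] / s0 - mean * mean));
    }
    return sd;
  }
}
```

With s0 = 3, s1 = (9, 12) and s2 = (35, 56), the sums for the points (1,2), (3,4) and (5,6), the new center is (3, 4) and the per-dimension standard deviation is sqrt(8/3), about 1.633, in both dimensions.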



clusterData also has two implementations to choose from, a sequential (single-machine) one and a distributed MapReduce one:
    if (runSequential) {
      clusterDataSeq(conf, input, clustersIn, output, measure);
    } else {
      clusterDataMR(conf, input, clustersIn, output, measure, convergenceDelta);
    }



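clusterDataMR uses SequenceFiles for both input and output; the output key is an IntWritable (the cluster id) and the value is a WeightedVectorWritable wrapping the vector: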
    job.setInputFormatClass(SequenceFileInputFormat.class);
    job.setOutputFormatClass(SequenceFileOutputFormat.class);
    job.setOutputKeyClass(IntWritable.class);
    job.setOutputValueClass(WeightedVectorWritable.class);


It is a map-only job, with no reduce phase:
    job.setMapperClass(KMeansClusterMapper.class);
    job.setNumReduceTasks(0);



The KMeansClusterMapper class
public class KMeansClusterMapper extends Mapper<WritableComparable<?>,VectorWritable,IntWritable,WeightedVectorWritable>
  private final Collection<Cluster> clusters = new ArrayList<Cluster>();
  private KMeansClusterer clusterer;
map() tags each point with the id of its nearest final cluster and writes it out:
  @Override
  protected void map(WritableComparable<?> key, VectorWritable point, Context context)
    throws IOException, InterruptedException {
    clusterer.outputPointWithClusterInfo(point.get(), clusters, context);
  }


outputPointWithClusterInfo iterates over all centers, finds the nearest one, and writes
key: the cluster id, value: the point as a WeightedVectorWritable (weight 1):

    AbstractCluster nearestCluster = null;
    double nearestDistance = Double.MAX_VALUE;
    for (AbstractCluster cluster : clusters) {
      Vector clusterCenter = cluster.getCenter();
      double distance = measure.distance(clusterCenter.getLengthSquared(), clusterCenter, vector);
      if (distance < nearestDistance || nearestCluster == null) {
        nearestCluster = cluster;
        nearestDistance = distance;
      }
    }
    context.write(new IntWritable(nearestCluster.getId()), new WeightedVectorWritable(1, vector));
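The final assignment is a pure per-point function, which is why clusterDataMR can run with zero reduce tasks. The sketch is hypothetical: ClusteredPoint stands in for the (IntWritable, WeightedVectorWritable) pair, ids[i] plays the role of getId() for centers[i], and the distance is squared Euclidean.

```java
import java.util.ArrayList;
import java.util.List;

// Sketch of outputPointWithClusterInfo as a per-point function (not Mahout code).
// ClusteredPoint is a hypothetical stand-in for the (IntWritable, WeightedVectorWritable) pair.
class ClusteredPoint {
  final int clusterId;
  final double weight;
  final double[] vector;

  ClusteredPoint(int clusterId, double weight, double[] vector) {
    this.clusterId = clusterId;
    this.weight = weight;
    this.vector = vector;
  }
}

class AssignPointsSketch {
  static double squaredDistance(double[] a, double[] b) {
    double sum = 0;
    for (int i = 0; i < a.length; i++) {
      double d = a[i] - b[i];
      sum += d * d;
    }
    return sum;
  }

  // ids[i] is the id of the cluster whose center is centers[i].
  // Each point is handled on its own, so no reduce step is needed.
  static List<ClusteredPoint> assign(int[] ids, double[][] centers, double[][] points) {
    List<ClusteredPoint> out = new ArrayList<>();
    for (double[] p : points) {
      int nearest = -1;
      double nearestDistance = Double.MAX_VALUE;
      for (int c = 0; c < centers.length; c++) {
        double d = squaredDistance(centers[c], p);
        if (d < nearestDistance || nearest == -1) {
          nearest = c;
          nearestDistance = d;
        }
      }
      out.add(new ClusteredPoint(ids[nearest], 1.0, p)); // weight is always 1, as above
    }
    return out;
  }
}
```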
