Mahout Source Code Analysis: Clustering -- The Cluster Iteration Model

When we walked through the clustering policies earlier, a few classes in the org.apache.mahout.clustering.iterator package were left uncovered; this post wraps them up.

ClusterIterator clusters samples using a ClusterClassifier over a specified number of iterations. It has three concrete methods.

iterate clusters data held in memory; its input is an Iterable over Vector samples.

	public ClusterClassifier iterate(Iterable<Vector> data,
			ClusterClassifier classifier, int numIterations) {
		ClusteringPolicy policy = classifier.getPolicy();
		for (int iteration = 1; iteration <= numIterations; iteration++) {
			for (Vector vector : data) {
				// update the clustering policy from the current classifier
				policy.update(classifier);
				// classification yields each model's probability for the sample
				Vector probabilities = classifier.classify(vector);
				// the policy turns the probabilities into training weights
				Vector weights = policy.select(probabilities);
				// training causes all models with non-zero weight to observe the data
				for (Iterator<Vector.Element> it = weights.iterateNonZero(); it
						.hasNext();) {
					int index = it.next().index();
					classifier.train(index, vector, weights.get(index));
				}
			}
			// compute the posterior models
			classifier.close();
		}
		return classifier;
	}
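
To see how this is meant to be called, here is a minimal, hypothetical driver. The toy data, the initial Kluster centers, and the KMeansClusteringPolicy are my own illustration assuming Mahout 0.7-era APIs; none of it comes from the Mahout source:

	import java.util.ArrayList;
	import java.util.List;
	
	import org.apache.mahout.clustering.Cluster;
	import org.apache.mahout.clustering.classify.ClusterClassifier;
	import org.apache.mahout.clustering.iterator.ClusterIterator;
	import org.apache.mahout.clustering.iterator.KMeansClusteringPolicy;
	import org.apache.mahout.clustering.kmeans.Kluster;
	import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
	import org.apache.mahout.math.DenseVector;
	import org.apache.mahout.math.Vector;
	
	public class InMemoryIterateDemo {
		public static void main(String[] args) {
			// two toy samples in the plane
			List<Vector> data = new ArrayList<Vector>();
			data.add(new DenseVector(new double[] {1, 1}));
			data.add(new DenseVector(new double[] {9, 9}));
			// one prior cluster per initial center; the id doubles as the model index
			List<Cluster> priors = new ArrayList<Cluster>();
			priors.add(new Kluster(new DenseVector(new double[] {0, 0}), 0,
					new EuclideanDistanceMeasure()));
			priors.add(new Kluster(new DenseVector(new double[] {10, 10}), 1,
					new EuclideanDistanceMeasure()));
			ClusterClassifier prior = new ClusterClassifier(priors,
					new KMeansClusteringPolicy());
			// ten rounds of classify -> select -> train -> close
			ClusterClassifier posterior = new ClusterIterator().iterate(data, prior, 10);
			for (Cluster c : posterior.getModels()) {
				System.out.println(c.getId() + " -> " + c.getCenter());
			}
		}
	}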

iterateSeq clusters sample vectors stored in SequenceFiles.

	public void iterateSeq(Configuration conf, Path inPath, Path priorPath,
			Path outPath, int numIterations) throws IOException {
		// restore the cluster classification model from the prior path
		ClusterClassifier classifier = new ClusterClassifier();
		classifier.readFromSeqFiles(conf, priorPath);
		
		Path clustersOut = null;
		int iteration = 1;
		while (iteration <= numIterations) {
			for (VectorWritable vw : new SequenceFileDirValueIterable<VectorWritable>(
					inPath, PathType.LIST, PathFilters.logsCRCFilter(), conf)) {
				// each sample vector in turn
				Vector vector = vw.get();
				// classification probabilities for the sample
				Vector probabilities = classifier.classify(vector);
				// the policy turns the probabilities into per-model training weights
				Vector weights = classifier.getPolicy().select(probabilities);
				// train every model with a non-zero weight on the sample
				for (Iterator<Vector.Element> it = weights.iterateNonZero(); it
						.hasNext();) {
					int index = it.next().index();
					classifier.train(index, vector, weights.get(index));
				}
			}
			// compute the posterior models
			classifier.close();
			// update the clustering policy
			classifier.getPolicy().update(classifier);
			// write this iteration's classifier
			clustersOut = new Path(outPath, Cluster.CLUSTERS_DIR + iteration);
			classifier.writeToSeqFiles(clustersOut);
			FileSystem fs = FileSystem.get(outPath.toUri(), conf);
			iteration++;
			if (isConverged(clustersOut, conf, fs)) {
				// stop iterating once converged
				break;
			}
		}
		// iteration was already incremented, so (iteration - 1) names the last
		// completed iteration; rename its output as the final result
		Path finalClustersIn = new Path(outPath, Cluster.CLUSTERS_DIR
				+ (iteration - 1) + Cluster.FINAL_ITERATION_SUFFIX);
		FileSystem.get(clustersOut.toUri(), conf).rename(clustersOut,
				finalClustersIn);
	}
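
Calling it looks much like the in-memory case, except the prior classifier and the samples live in SequenceFiles. A hypothetical sketch, with illustrative paths and the prior classifier from the previous sketch:

	// hypothetical paths; 'prior' is the ClusterClassifier built in the sketch above
	Configuration conf = new Configuration();
	Path inPath = new Path("testdata/points");   // SequenceFiles of VectorWritable
	Path priorPath = new Path("testdata/prior"); // serialized prior classifier
	Path outPath = new Path("output");
	prior.writeToSeqFiles(priorPath);
	new ClusterIterator().iterateSeq(conf, inPath, priorPath, outPath, 10);
	// the final clusters land in output/clusters-N-final, N being the last iteration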

iterateMR is the MapReduce version of the algorithm: the mapper is CIMapper and the reducer is CIReducer.
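
The post does not reproduce iterateMR itself; conceptually, each iteration submits one Hadoop job wired roughly as below. This is my own sketch inferred from the CIMapper/CIReducer signatures that follow, not the verbatim Mahout source:

	// sketch of one iteration's job setup, inferred from CIMapper/CIReducer below
	conf.set(ClusterIterator.PRIOR_PATH_KEY, priorPath.toString()); // read in CIMapper.setup
	Job job = new Job(conf, "ClusterIterator iteration " + iteration);
	job.setMapperClass(CIMapper.class);
	job.setReducerClass(CIReducer.class);
	job.setMapOutputKeyClass(IntWritable.class);       // cluster index
	job.setMapOutputValueClass(ClusterWritable.class); // the cluster itself
	job.setOutputKeyClass(IntWritable.class);
	job.setOutputValueClass(ClusterWritable.class);
	job.setInputFormatClass(SequenceFileInputFormat.class);
	job.setOutputFormatClass(SequenceFileOutputFormat.class);
	FileInputFormat.addInputPath(job, inPath);
	FileOutputFormat.setOutputPath(job, clustersOut);
	job.waitForCompletion(true);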

Let's start with CIMapper. In the setup phase, it deserializes the ClusterClassifier from the prior path and uses it to update the clustering policy:

	protected void setup(Context context) throws IOException,
			InterruptedException {
		Configuration conf = context.getConfiguration();
		String priorClustersPath = conf.get(ClusterIterator.PRIOR_PATH_KEY);
		classifier = new ClusterClassifier();
		// read the ClusterClassifier back from the prior path
		classifier.readFromSeqFiles(conf, new Path(priorClustersPath));
		// fetch the clustering policy and update it from the classifier
		policy = classifier.getPolicy();
		policy.update(classifier);
		super.setup(context);
	}

In the map phase, it first computes the classification probabilities of each sample under the cluster classification model, then trains the classifier on every model whose selection weight is non-zero:

	protected void map(WritableComparable<?> key, VectorWritable value,
			Context context) throws IOException, InterruptedException {
		// classification probabilities of the sample under each cluster model
		Vector probabilities = classifier.classify(value.get());
		Vector selections = policy.select(probabilities);
		// train every model whose selection weight is non-zero
		for (Iterator<Element> it = selections.iterateNonZero(); it.hasNext();) {
			Element el = it.next();
			classifier.train(el.index(), value.get(), el.get());
		}
	}

The cleanup phase emits the clusters held by the cluster classification model:

	protected void cleanup(Context context) throws IOException,
			InterruptedException {
		List<Cluster> clusters = classifier.getModels();
		ClusterWritable cw = new ClusterWritable();
		// emit every cluster, keyed by its index
		for (int index = 0; index < clusters.size(); index++) {
			cw.setValue(clusters.get(index));
			context.write(new IntWritable(index), cw);
		}
		super.cleanup(context);
	}

CIReducer merges the clusters. CIMapper's output key is the cluster index and its value is a ClusterWritable; the reducer's job is to merge all clusters that share the same index:

	protected void reduce(IntWritable key, Iterable<ClusterWritable> values,
			Context context) throws IOException, InterruptedException {
		Iterator<ClusterWritable> iter = values.iterator();
		ClusterWritable first = null;
		while (iter.hasNext()) {
			ClusterWritable cw = iter.next();
			if (first == null) {
				first = cw;
			} else {
				// fold this cluster's statistics into the first one
				first.getValue().observe(cw.getValue());
			}
		}
		// close the merged cluster so its posterior parameters are computed
		List<Cluster> models = new ArrayList<Cluster>();
		models.add(first.getValue());
		classifier = new ClusterClassifier(models, policy);
		classifier.close();
		
		context.write(key, first);
	}

The merging relies on the observe method; see the post on the cluster model for the details.
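
For reference, the merge boils down to adding sufficient statistics. A simplified paraphrase of AbstractCluster's observe (null checks omitted; see that post for the real code):

	// simplified paraphrase: merging adds the other cluster's sufficient statistics
	public void observe(AbstractCluster other) {
		setS0(getS0() + other.getS0());     // observation count
		setS1(getS1().plus(other.getS1())); // running sum of the vectors
		setS2(getS2().plus(other.getS2())); // running sum of squares
	}

After the merge, classifier.close() triggers computeParameters(), which re-derives the center and radius from S0, S1, and S2.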
