数据挖掘笔记-聚类-KMeans-文档聚类

本篇主要是根据KMediods算法实现文档集的聚类。首先是要将需要聚类的文档进行向量化处理,这里采用的是TFIDF值来表示。文档之间的距离选用的是余弦距离,后面步骤没什么变化。KMediods算法聚类完成之后发现结果不是很理想,后面发现将数据降维后,结果还是比较理想的。


java实现代码如下:

public class DocKMediodsCluster extends AbstractCluster {
	
	//阀值
	public static final double THRESHOLD = 0.028;
	//迭代次数
	public static final int ITER_NUM = 10;
	
	/*
	 * 初始化数据
	 */
	public List<DataPoint> initData() {
		List<DataPoint> dataPoints = new ArrayList<DataPoint>();
		try {
			String path = DocKMediodsCluster.class.getClassLoader().getResource("测试").toURI().getPath();
			DocumentSet documentSet = DocumentLoader.loadDocumentSetByThread(path);
			List<Document> documents = documentSet.getDocuments();
			DocumentUtils.calculateTFIDF_0(documents);
			for(Document document : documents) {
				DataPoint dataPoint = new DataPoint();
				dataPoint.setValues(document.getTfidfWords());
				dataPoint.setCategory(document.getCategory());
				dataPoints.add(dataPoint);
			}
		} catch (URISyntaxException e) {
			e.printStackTrace();
		}
		return dataPoints;
	}
	
	//随机生成中心点,并生成初始的K个聚类
	public List<DataPointCluster> genInitCluster(List<DataPoint> points, int k) {
		List<DataPointCluster> clusters = new ArrayList<DataPointCluster>();
		Random random = new Random();
		Set<String> categories = new HashSet<String>();
		while (clusters.size() < k) {
			DataPoint center = points.get(random.nextInt(points.size()));
			String category = center.getCategory();
			if (categories.contains(category)) continue;
			categories.add(category);
			DataPointCluster cluster = new DataPointCluster();
			cluster.setCenter(center);
			cluster.getDataPoints().add(center);
			clusters.add(cluster);
		}
		return clusters;
	}
	
	//将点归入到聚类中
	public void handleCluster(List<DataPoint> points, List<DataPointCluster> clusters, int iterNum) {
		System.out.println("iterNum: " + iterNum);
		for (DataPoint point : points) {
			DataPointCluster maxCluster = null;
			double maxDistance = Integer.MIN_VALUE;
			for (DataPointCluster cluster : clusters) {
				DataPoint center = cluster.getCenter();
				double distance = DistanceUtils.cosine(point.getValues(), center.getValues());
				if (distance > maxDistance) {
					maxDistance = distance;
					maxCluster = cluster;
				}
			}
			if (null != maxCluster) {
				maxCluster.getDataPoints().add(point);
			}
		}
		//终止条件定义为原中心点与新中心点距离小于一定阀值
		//当然也可以定义为原中心点等于新中心点
		boolean flag = true;
		for (DataPointCluster cluster : clusters) {
			DataPoint center = cluster.getCenter();
			DataPoint newCenter = cluster.computeMediodsCenter();
			double distance = DistanceUtils.cosine(
					newCenter.getValues(), center.getValues());
			System.out.println("distaince: " + distance);
			if (distance > THRESHOLD) {
				flag = false;
				cluster.setCenter(newCenter);
			}
		}
		System.out.println("--------------");
		if (!flag && iterNum < ITER_NUM) {
			for (DataPointCluster cluster : clusters) {
				cluster.getDataPoints().clear();
			}
			handleCluster(points, clusters, ++iterNum);
		}
	}
	
	public List<Map.Entry<String, Double>> sortMap(Map<String, Double> map) {
		List<Map.Entry<String, Double>> list = 
				new ArrayList<Map.Entry<String, Double>>(map.entrySet());
		Collections.sort(list, new Comparator<Map.Entry<String, Double>>() {
			@Override
			public int compare(Entry<String, Double> o1,
					Entry<String, Double> o2) {
				if (o1.getValue().isNaN()) {
					o1.setValue(0.0);
				}
				if (o2.getValue().isNaN()) {
					o2.setValue(0.0);
				}
				return -o1.getValue().compareTo(o2.getValue());
			}
		});
		return list;
	}
	
	public void build() {
		List<DataPoint> points = initData();
		List<DataPointCluster> clusters = genInitCluster(points, 4);
		for (DataPointCluster cluster : clusters) {
			System.out.println("center: " + cluster.getCenter().getCategory());
		}
		handleCluster(points, clusters, 0);
		int success = 0, failure = 0;
		for (DataPointCluster cluster : clusters) {
			String category = cluster.getCenter().getCategory();
			System.out.println("center: " + category + "--" + cluster.getDataPoints().size());
			for (DataPoint dataPoint : cluster.getDataPoints()) {
				String dpCategory = dataPoint.getCategory();
				System.out.println(dpCategory);
				if (category.equals(dpCategory)) {
					success++;
				} else {
					failure++;
				}
			}
			System.out.println("----------");
		}
		System.out.println("total: " + (success + failure) + " success: " + success + " failure: " + failure);
	}
	
	public static void main(String[] args) {
		new DocKMediodsCluster().build();
	}
	
}

/**
 * 计算TFIDF
 * TF计算是词频除以总词数
 * @param documents
 */
public static void calculateTFIDF_0(List<Document> documents) {
	int docTotalCount = documents.size();
	for (Document document : documents) {
		Map<String, Double> tfidfWords = document.getTfidfWords();
		int wordTotalCount = document.getWords().length;
		Map<String, Integer> docWords = DocumentHelper.wordsInDocStatistics(document);
		for (String word : docWords.keySet()) {
			double wordCount = docWords.get(word);
			double tf = wordCount / wordTotalCount;
			double docCount = DocumentHelper.wordInDocsStatistics(word, documents) + 1;
			double idf = Math.log(docTotalCount / docCount);
			double tfidf = tf * idf;
			tfidfWords.put(word, tfidf);
		}
		System.out.println("doc " + document.getName() + " calculate tfidf finish");
	}
}
	
/**
 * 计算TFIDF
 * TF计算是词频除以词频最高数
 * @param documents
 */
public static void calculateTFIDF_1(List<Document> documents) {
	int docTotalCount = documents.size();
	for (Document document : documents) {
		Map<String, Double> tfidfWords = document.getTfidfWords();
		List<Map.Entry<String, Double>> list = 
				new ArrayList<Map.Entry<String, Double>>(tfidfWords.entrySet());
		Collections.sort(list, new Comparator<Map.Entry<String, Double>>() {
			@Override
			public int compare(Entry<String, Double> o1,
					Entry<String, Double> o2) {
				return -o1.getValue().compareTo(o2.getValue());
			}
		});
		if (list.size() == 0) continue; 
		double wordTotalCount = list.get(0).getValue();
		Map<String, Integer> docWords = DocumentHelper.wordsInDocStatistics(document);
		for (String word : docWords.keySet()) {
			double wordCount = docWords.get(word);
			double tf = wordCount / wordTotalCount;
			double docCount = DocumentHelper.wordInDocsStatistics(word, documents) + 1;
			double idf = Math.log(docTotalCount / docCount);
			double tfidf = tf * idf;
			tfidfWords.put(word, tfidf);
		}
		System.out.println("doc " + document.getName() + " calculate tfidf finish");
	}
}



代码托管:https://github.com/fighting-one-piece/repository-datamining.git



你可能感兴趣的:(数据挖掘,聚类,kmeans,文本聚类,特征选择降维)