Java实现K-Means聚类算法

K-means算法基本思想

在数据集中根据一定策略选择K个点作为每个簇的初始中心,将数据划分到距离这K个点最近的簇中,共分成K个类。也就是说将数据划分成K个簇完成一次划分,但形成的新簇并不一定是最好的划分,因此生成的新簇中,重新计算每个簇的中心点,然后再重新进行划分,直到每次划分的结果保持不变。

算法步骤

  • 随机选择K个中心点
  • 把每个数据点分配到离它最近的中心点(此处的距离采用欧氏距离)
  • 重新计算每类中的点到该类中心点距离的平均值
  • 分配每个数据到它最近的中心点
  • 重复步骤3和4,直到每个类别中的数据不再发生变化。

Java实现K-means聚类算法

现有若干鸢尾花的数据,每朵鸢尾花有4个数据,分别为萼片长(单位:厘米)、萼片宽(单位厘米)、花瓣长度(单位厘米)和花瓣宽(单位厘米)。我们希望能找到可行的方法可以按每朵花的4个数据的差异将这些鸢尾花分成若干类,让每一类尽可能的准确,以便帮助植物专家对这些花进行进一步的分析。编程实现K-Means聚类算法,将鸢尾花分类成3类。

数据集样本如下:
Java实现K-Means聚类算法_第1张图片
先将以上数据写入文件,文件中的内容如下:
Java实现K-Means聚类算法_第2张图片
在这里插入图片描述
程序运行结果:
由于数据太长,只截取了一部分
Java实现K-Means聚类算法_第3张图片
Java实现K-Means聚类算法_第4张图片

从以上结果可以看到,第10次迭代后产生的分类结果和第9次完全相同,故分类完成,共迭代10次,算法结束。最后一次的迭代结果即为最终的分类结果。

代码如下

package d;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;

public class Kmeans {
     
	//记录迭代的次数
	static int count = 1;
	//文件所在路径
	static String filePath = System.getProperty("user.dir")+"\\src\\d\\Iris.txt";
	//储存从文件中读取的数据
	static ArrayList<ArrayList<Float>> table = new ArrayList<ArrayList<Float>>();
	//储存分类一的结果
	static ArrayList<ArrayList<Float>> alist = new ArrayList<ArrayList<Float>>();
	//储存分类二的结果
	static ArrayList<ArrayList<Float>> blist = new ArrayList<ArrayList<Float>>();
	//储存分类三的结果
	static ArrayList<ArrayList<Float>> clist = new ArrayList<ArrayList<Float>>();
	//记录初始随机产生的3个聚类中心
	static ArrayList<ArrayList<Float>> randomList = new ArrayList<ArrayList<Float>>();
	
	//读取文件中的数据,储存到集合中
	public static ArrayList<ArrayList<Float>> readTable(String filePath){
     
		ArrayList<Float> d = null;
		File file = new File(filePath);
		try {
     
			InputStreamReader isr = new InputStreamReader(new FileInputStream(file));
			BufferedReader bf = new BufferedReader(isr);
			String str = null;
			while((str = bf.readLine()) != null) {
     
				d = new ArrayList<Float>();
				String[] str1 = str.split(",");
				for(int i = 0; i < str1.length ; i++) {
     
					d.add(Float.parseFloat(str1[i]));
				}
				table.add(d);
			}
//			System.out.println(table);
			bf.close();
			isr.close();
		} catch (Exception e) {
     
			e.printStackTrace();
			System.out.println("文件不存在!");
		}
		return table;
	}
	
	//随机产生3个初始聚类中心
	public static ArrayList<ArrayList<Float>> randomList() {
     
		int[] list = new int[3];
		//产生3个互不相同的随机数
		do {
     
			list[0] = (int)(Math.random()*30);
			list[1] = (int)(Math.random()*30);
			list[2] = (int)(Math.random()*30);
		}while(list[0] == list[1] && list[0] == list[2] && list[1] == list[2]);
//		System.out.println("索引:"+list[0]+" "+list[1]+" "+list[2]);
//为了测试方便,我这里去数据集中前3个作为初始聚类中心
		for(int i = 0; i < 3 ; i++) {
     
			//randomList.add(list[i]);
			randomList.add(table.get(i));
		 }
		return randomList;
	}
	
	//比较两个数的大小,并返回其中较小的数
	public static double minNumber(double x, double y) {
     
		if(x < y) {
     
			return x;
		}
		return y;
	}
	
	//计算各个数据到三个中心点的距离,然后分成三类
	public static void eudistance(ArrayList<ArrayList<Float>> list){
     
		alist.clear();
		blist.clear();
		clist.clear();
		double minNumber;
		double distancea,distanceb,distancec;
//		System.out.println("randomList:"+randomList);
		for(int i = 0; i < table.size() ; i++) {
     
			distancea = Math.pow(table.get(i).get(1)-list.get(0).get(1), 2) +
					Math.pow(table.get(i).get(2)-list.get(0).get(2), 2) + 
					Math.pow(table.get(i).get(3)-list.get(0).get(3), 2) + 
					Math.pow(table.get(i).get(4)-list.get(0).get(4), 2);
			distanceb = Math.pow(table.get(i).get(1)-list.get(1).get(1), 2) +
					Math.pow(table.get(i).get(2)-list.get(1).get(2), 2) +
					Math.pow(table.get(i).get(3)-list.get(1).get(3), 2) +
					Math.pow(table.get(i).get(4)-list.get(1).get(4), 2);
			distancec = Math.pow(table.get(i).get(1)-list.get(2).get(1), 2) +
					Math.pow(table.get(i).get(2)-list.get(2).get(2), 2) +
					Math.pow(table.get(i).get(3)-list.get(2).get(3), 2) +
					Math.pow(table.get(i).get(4)-list.get(2).get(4), 2);
			minNumber = minNumber(minNumber(distancea,distanceb),distancec);
			if(minNumber == distancea) {
     
				alist.add(table.get(i));
			}else if(minNumber == distanceb) {
     
				blist.add(table.get(i));
			}else {
     
				clist.add(table.get(i));
			}
		 }
		System.out.println("第"+count+"次迭代:");
		System.out.println(alist);
		System.out.println(blist);
		System.out.println(clist);
		System.out.println("\n");
		count++;
	}
	
	//计算每个类中四个数据的平均值,重新生成聚类中心
	public static ArrayList<Float> newList(ArrayList<ArrayList<Float>> list){
     
		float avnum1,avnum2,avnum3,avnum4,c=0f;
		float sum1 = 0,sum2 = 0,sum3 = 0,sum4 = 0;
		ArrayList<Float> k = new ArrayList<Float>();
		for(int i = 0; i < list.size(); i++) {
     
			sum1 += list.get(i).get(1);
			sum2 += list.get(i).get(2);
			sum3 += list.get(i).get(3);
			sum4 += list.get(i).get(4);
		}
		avnum1 = (float)(sum1*1.0 / list.size());
		avnum2 = (float)(sum2*1.0 / list.size());
		avnum3 = (float)(sum3*1.0 / list.size());
		avnum4 = (float)(sum4*1.0 / list.size());
		k.add(c);
		k.add(avnum1);
		k.add(avnum2);
		k.add(avnum3);
		k.add(avnum4);
		return k;
	}
	
	//判断两个集合的元素是否完全相同,若相同,则返回1;否则,返回0
	public static int same(ArrayList<ArrayList<Float>> list1, ArrayList<ArrayList<Float>> list2) {
     
		int countn = 0;
		if(list1.size()==list2.size()) {
     
			for(int i = 0; i < list1.size() ; i++) {
     
				for(int j = 0; j < list2.size() ; j++) {
     
					if(list1.get(i).containsAll(list2.get(j)) && list2.get(j).containsAll(list1.get(i))) {
     
						countn++;
						break;
					}
				}
			}
		}
		if(countn == list1.size()) {
     
			return 1;
		}else {
     
			return 0;
		}
	}
	
	//迭代求出最后的分类结果
	public static void kmeans() {
     
		int a,b,c,k=0;
		ArrayList<ArrayList<Float>> klist = null;
		ArrayList<ArrayList<Float>> arlist = null;
		ArrayList<ArrayList<Float>> brlist = null;
		ArrayList<ArrayList<Float>> crlist = null;
		do {
     
			klist = new ArrayList<ArrayList<Float>>();
			arlist = new ArrayList<ArrayList<Float>>();
			brlist = new ArrayList<ArrayList<Float>>();
			crlist = new ArrayList<ArrayList<Float>>();
			arlist.addAll(alist);
			brlist.addAll(blist);
			crlist.addAll(clist);
			klist.clear();
			klist.add(newList(alist));
			klist.add(newList(blist));
			klist.add(newList(clist));
			eudistance(klist);
			a = same(alist,arlist);
			b = same(blist,brlist);
			c = same(clist,crlist);
			if(a == 1 && b == 1 && c == 1) {
     
				Kmeans.count = 1;
				break;
			}
		}while(true);
	}

	public static void main(String[] args) {
     
		ArrayList<ArrayList<Float>> rlist = new ArrayList<ArrayList<Float>>();
		readTable(filePath);
		rlist = randomList();
		eudistance(rlist);
		kmeans();
	}
}

你可能感兴趣的:(数据挖掘)