1、前面一篇文章算法——K均值聚类算法(Java实现)简单的实现了一下K均值分类算法,这节我们对于他的应用进行一个扩展应用
2、目标为对对象的分类
3、具体实现如下
1)首先建立一个基类KmeansObject,目的为继承该类的子类都可以应用我们的k均值算法进行分类,代码如下
package org.cyxl.util.algorithm; /** * 所有使用k均值分类算法的对象都必须继承自该对象 * @author cyxl * @version 1.0 2012-05-24 * @since 1.0 * */ public class KmeansObject { public float compare; //比较因子 }
package org.cyxl.util.algorithm; import java.util.ArrayList; import java.util.Random; /** * K均值聚类算法 */ public class CommonKmeans { private int k;// 分成多少簇 private int m;// 迭代次数 private int dataSetLength;// 数据集元素个数,即数据集的长度 private ArrayList<KmeansObject> dataSet;// 数据集链表 private ArrayList<KmeansObject> center;// 中心链表 private ArrayList<ArrayList<KmeansObject>> cluster; // 簇 private ArrayList<Float> jc;// 误差平方和,k越接近dataSetLength,误差越小 private Random random; /** * 设置需分组的原始数据集 * * @param dataSet */ public void setDataSet(ArrayList<KmeansObject> dataSet) { this.dataSet = dataSet; } /** * 获取结果分组 * * @return 结果集 */ public ArrayList<ArrayList<KmeansObject>> getCluster() { return cluster; } /** * 构造函数,传入需要分成的簇数量 * * @param k * 簇数量,若k<=0时,设置为1,若k大于数据源的长度时,置为数据源的长度 */ public CommonKmeans(int k) { if (k <= 0) { k = 1; } this.k = k; } /** * 初始化 */ private void init() { m = 0; random = new Random(); if (dataSet == null || dataSet.size() == 0) { initDataSet(); } dataSetLength = dataSet.size(); if (k > dataSetLength) { k = dataSetLength; } center = initCenters(); cluster = initCluster(); jc = new ArrayList<Float>(); } /** * 如果调用者未初始化数据集,则采用内部测试数据集 */ private void initDataSet() { dataSet = new ArrayList<KmeansObject>(); for(int i=0;i<10;i++) { int temp = random.nextInt(100); KmeansObject ko=new KmeansObject(); ko.compare=temp; dataSet.add(ko); } } /** * 初始化中心数据链表,分成多少簇就有多少个中心点 * * @return 中心点集 */ private ArrayList<KmeansObject> initCenters() { ArrayList<KmeansObject> center = new ArrayList<KmeansObject>(); int[] randoms = new int[k]; boolean flag; int temp = random.nextInt(dataSetLength); randoms[0] = temp; for (int i = 1; i < k; i++) { flag = true; while (flag) { temp = random.nextInt(dataSetLength); int j = 0; // 不清楚for循环导致j无法加1 // for(j=0;j<i;++j) // { // if(temp==randoms[j]); // { // break; // } // } while (j < i) { if (temp == randoms[j]) { break; } j++; } if (j == i) { flag = false; } } randoms[i] = temp; } for (int i = 0; i < k; i++) { center.add(dataSet.get(randoms[i]));// 生成初始化中心链表 } return center; } /** * 初始化簇集合 * * @return 一个分为k簇的空数据的簇集合 */ private ArrayList<ArrayList<KmeansObject>> initCluster() { ArrayList<ArrayList<KmeansObject>> cluster = new ArrayList<ArrayList<KmeansObject>>(); for (int i = 0; i < k; i++) { cluster.add(new ArrayList<KmeansObject>()); } return cluster; } /** * 计算两个点之间的距离 * * @param element * 点1 * @param center * 点2 * @return 距离 */ private float distance(KmeansObject element, KmeansObject center) { float distance = 0.0f; distance=Math.abs(element.compare-center.compare); return distance; } /** * 获取距离集合中最小距离的位置 * * @param distance * 距离数组 * @return 最小距离在距离数组中的位置 */ private int minDistance(float[] distance) { float minDistance = distance[0]; int minLocation = 0; for (int i = 1; i < distance.length; i++) { if (distance[i] < minDistance) { minDistance = distance[i]; minLocation = i; } else if (distance[i] == minDistance) // 如果相等,随机返回一个位置 { if (random.nextInt(10) < 5) { minLocation = i; } } } return minLocation; } /** * 核心,将当前元素放到最小距离中心相关的簇中 */ private void clusterSet() { float[] distance = new float[k]; for (int i = 0; i < dataSetLength; i++) { for (int j = 0; j < k; j++) { distance[j] = distance(dataSet.get(i), center.get(j)); } int minLocation = minDistance(distance); cluster.get(minLocation).add(dataSet.get(i));// 核心,将当前元素放到最小距离中心相关的簇中 } } /** * 求两点误差平方的方法 * * @param element * 点1 * @param center * 点2 * @return 误差平方 */ private float errorSquare(KmeansObject element, KmeansObject center) { float x = Math.abs(element.compare-center.compare); float errSquare = x * x; return errSquare; } /** * 计算误差平方和准则函数方法 */ private void countRule() { float jcF = 0; for (int i = 0; i < cluster.size(); i++) { for (int j = 0; j < cluster.get(i).size(); j++) { jcF += errorSquare(cluster.get(i).get(j), center.get(i)); } } jc.add(jcF); } /** * 设置新的簇中心方法 */ private void setNewCenter() { for (int i = 0; i < k; i++) { int n = cluster.get(i).size(); if (n != 0) { KmeansObject newCenter = new KmeansObject(); for (int j = 0; j < n; j++) { newCenter.compare += cluster.get(i).get(j).compare; } // 设置一个平均值 newCenter.compare=newCenter.compare/n; center.set(i, newCenter); } } } /** * 打印数据,测试用 * * @param dataArray * 数据集 * @param dataArrayName * 数据集名称 */ public void printDataArray(ArrayList<KmeansObject> dataArray, String dataArrayName) { for (int i = 0; i < dataArray.size(); i++) { System.out.println("print:" + dataArrayName + "[" + i + "]={" + dataArray.get(i) + "}"); } System.out.println("==================================="); } /** * Kmeans算法核心过程方法 */ private void kmeans() { init(); // 循环分组,直到误差不变为止 while (true) { clusterSet(); countRule(); // 误差不变了,分组完成 if (m != 0) { if (jc.get(m) - jc.get(m - 1) == 0) { break; } } setNewCenter(); m++; cluster.clear(); cluster = initCluster(); } } /** * 执行算法 */ public void execute() { long startTime = System.currentTimeMillis(); System.out.println("kmeans begins"); kmeans(); long endTime = System.currentTimeMillis(); System.out.println("kmeans running time=" + (endTime - startTime) + "ms"); System.out.println("kmeans ends"); System.out.println(); } }
package org.cyxl.util.algorithm; public class Person extends KmeansObject { String name=""; int age=0; float qz=1; //权重 public Person(){} public Person(String name,int age,float qz) { this.name=name; this.age=age; this.qz=qz; } public String getName() { return name; } public void setName(String name) { this.name = name; } public int getAge() { return age; } public void setAge(int age) { this.age = age; } public float getQz() { return qz; } public void setQz(float qz) { this.qz = qz; } public String toString() { return "name:"+this.name+";age:"+this.age+";qz:"+this.qz+";compare:"+super.compare; } }
CommonKmeans k=new CommonKmeans(5); ArrayList<KmeansObject> list=new ArrayList<KmeansObject>(); for(int i=0;i<10;i++) { float qz=(float)(new Random().nextInt(10))/10; Person p=new Person("name"+i,i,qz); p.compare=new Random().nextInt(100)*p.getQz(); list.add(p); } k.setDataSet(list); k.printDataArray(k.dataSet, "before"); k.execute(); ArrayList<ArrayList<KmeansObject>> cluster=k.getCluster(); //查看结果 for(int i=0;i<cluster.size();i++) { k.printDataArray(cluster.get(i), "cluster["+i+"]"); }
print:before[0]={name:name0;age:0;qz:0.0;compare:0.0} print:before[1]={name:name1;age:1;qz:0.9;compare:48.6} print:before[2]={name:name2;age:2;qz:0.9;compare:57.6} print:before[3]={name:name3;age:3;qz:0.4;compare:28.4} print:before[4]={name:name4;age:4;qz:0.0;compare:0.0} print:before[5]={name:name5;age:5;qz:0.4;compare:33.600002} print:before[6]={name:name6;age:6;qz:0.5;compare:2.0} print:before[7]={name:name7;age:7;qz:0.2;compare:14.6} print:before[8]={name:name8;age:8;qz:0.6;compare:5.4} print:before[9]={name:name9;age:9;qz:0.9;compare:52.199997} =================================== kmeans begins kmeans running time=0ms kmeans ends print:cluster[0][0]={name:name3;age:3;qz:0.4;compare:28.4} print:cluster[0][1]={name:name5;age:5;qz:0.4;compare:33.600002} =================================== print:cluster[1][0]={name:name7;age:7;qz:0.2;compare:14.6} =================================== print:cluster[2][0]={name:name2;age:2;qz:0.9;compare:57.6} =================================== print:cluster[3][0]={name:name1;age:1;qz:0.9;compare:48.6} print:cluster[3][1]={name:name9;age:9;qz:0.9;compare:52.199997} =================================== print:cluster[4][0]={name:name0;age:0;qz:0.0;compare:0.0} print:cluster[4][1]={name:name4;age:4;qz:0.0;compare:0.0} print:cluster[4][2]={name:name6;age:6;qz:0.5;compare:2.0} print:cluster[4][3]={name:name8;age:8;qz:0.6;compare:5.4} ===================================
1)基类KmeansObject定义了一个compare,我们把它叫做比较因子,分类时只要就是对分类因子进行分类计算的。所以这个分类因子很重要,每个对象的分类因子可以具体的根据业务进行计算设置。比如我们客户端测试代码中的比较因子的计算方法是,首先给每个对象赋予一个权值qz,然后根据权值和年龄的乘积(具体计算方法根据业务定)来对人群进行分类
2)该算法中对于比较因子compare的计算是影响该算法准确性的一个很重要方面,具体表现在距离(distance方法)和误差(errorSquare方法)计算中。想要改善该算法可以从这两个方法中进行修改
3)当然,我对于这个算法的实现和应用都还是很浅。如果有什么不对或者可以改善的地方请不吝赐教