3.聚类–K-means的Java实现

K-means的步骤

输入: 含 n 个样本的数据集,簇的数目 K

输出: K 个簇

算法步骤:

1. 初始化 K 个聚类中心 C1, C2, …, Ck(通常随机选择)

2.repeat 步骤3,4

3. 将数据集中的每个样本分配到与之最近的聚类中心 Ci 所在的簇;

4. 更新聚类中心Ci,即计算各个簇的样本均值;

5. 直到样本分配不再改变

上代码:

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marker annotation for fields that take part in the k-means computation.
 *
 * <p>Put it on a field of the element class to have that field used as a
 * clustering feature. Only numeric properties are supported.
 *
 * @author 阿飞哥
 */
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
public @interface KmeanField {
}

 

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

/**
 * Reflection-based k-means clustering over bean-style objects.
 *
 * <p>The clustering features are the fields of the element class that carry the
 * {@link KmeanField} annotation. Every feature is read through a getter named
 * {@code getXxx()} and written to a newly created center object through
 * {@code setXxx(double)}. Because the setter is looked up with a single
 * {@code double} parameter, only {@code double} properties are fully supported;
 * a feature without such a setter keeps the default value of a fresh instance.
 *
 * <p>The first {@code k} elements of the input list are used as the initial
 * cluster centers.
 *
 * @param <T> element type; needs an accessible no-argument constructor so that
 *            new center objects can be created
 * @author 阿飞哥
 */
public class Kmeans<T> {

    /**
     * All elements to be clustered.
     */
    private List<T> players = new ArrayList<T>();

    /**
     * Runtime class of the elements, taken from the first list element.
     */
    private Class<T> classT;

    /**
     * Current cluster centers; {@code null} when built with the no-arg constructor.
     */
    private List<T> initPlayers;

    /**
     * Names of the fields that take part in the k-means computation.
     */
    private List<String> fieldNames = new ArrayList<String>();

    /**
     * Number of clusters.
     */
    private int k = 1;

    /**
     * Creates an instance without data; {@link #comput()} then yields empty clusters.
     */
    public Kmeans() {

    }

    /**
     * Prepares a clustering run over {@code list} with {@code k} clusters.
     *
     * @param list elements to cluster; must not be null or empty
     * @param k    number of clusters; must be between 1 and {@code list.size()}
     * @throws IllegalArgumentException if {@code list} is null/empty or {@code k}
     *                                  is out of range
     */
    public Kmeans(List<T> list, int k) {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("list must contain at least one element");
        }
        if (k < 1 || k > list.size()) {
            throw new IllegalArgumentException(
                    "k must be between 1 and " + list.size() + " but was " + k);
        }
        this.players = list;
        this.k = k;

        // getClass() returns Class<? extends Object>; the element is a T, so the cast is safe
        // as long as the list is homogeneous.
        @SuppressWarnings("unchecked")
        Class<T> elementClass = (Class<T>) list.get(0).getClass();
        this.classT = elementClass;

        Field[] fields = this.classT.getDeclaredFields();
        System.out.println("fields---------------------------------------------="+fields.length);
        for (Field field : fields) {
            if (field.getAnnotation(KmeanField.class) != null) {
                fieldNames.add(field.getName());
                System.out.println("fieldNames.add"+ field.getName());
            }
        }

        // Seed the centers with the first k elements.
        initPlayers = new ArrayList<T>(k);
        for (int i = 0; i < k; i++) {
            initPlayers.add(players.get(i));
        }
    }

    /**
     * Runs the clustering until no cluster center changes any more.
     *
     * @return an array of {@code k} lists; list {@code i} holds the elements
     *         assigned to cluster {@code i}
     * @throws IllegalStateException if a new center cannot be computed (see
     *                               {@link #findNewCenter(List)})
     */
    public List<T>[] comput() {
        // Generic arrays cannot be created directly; only List<T> values are ever stored.
        @SuppressWarnings("unchecked")
        List<T>[] results = new ArrayList[k];

        if (initPlayers == null) {
            // Built through the no-arg constructor: there is nothing to cluster.
            for (int i = 0; i < k; i++) {
                results[i] = new ArrayList<T>();
            }
            return results;
        }

        boolean centerChanged = true;
        while (centerChanged) {
            centerChanged = false;
            for (int i = 0; i < k; i++) {
                results[i] = new ArrayList<T>();
            }

            // Assignment step: every element goes to its nearest center.
            for (T p : players) {
                double[] dists = new double[k];
                for (int j = 0; j < initPlayers.size(); j++) {
                    dists[j] = distance(initPlayers.get(j), p);
                }
                results[computOrder(dists)].add(p);
            }

            // Update step: recompute the center of every cluster.
            for (int i = 0; i < k; i++) {
                T newCenter = findNewCenter(results[i]);
                if (newCenter == null) {
                    // Storing a null center would silently produce a meaningless clustering.
                    throw new IllegalStateException("Cannot compute the center of cluster " + i
                            + ": " + classT.getName() + " needs an accessible no-arg constructor"
                            + " and numeric values for all @KmeanField properties");
                }
                T oldCenter = initPlayers.get(i);
                if (!IsPlayerEqual(newCenter, oldCenter)) {
                    centerChanged = true;
                    initPlayers.set(i, newCenter);
                }
            }
        }
        return results;
    }

    /**
     * Tells whether two objects have equal values for all clustering features.
     *
     * @param p1 first object, may be null
     * @param p2 second object, may be null
     * @return {@code true} if both are the same reference or all feature values
     *         are equal; {@code false} if exactly one is null, a value differs,
     *         or the comparison fails
     */
    public boolean IsPlayerEqual(T p1, T p2) {
        if (p1 == p2) {
            return true;
        }
        if (p1 == null || p2 == null) {
            return false;
        }

        try {
            for (String fieldName : fieldNames) {
                String getName = accessorName("get", fieldName);
                Object value1 = invokeMethod(p1, getName, null);
                Object value2 = invokeMethod(p2, getName, null);
                // Null-safe: a getter that cannot be invoked yields null on both sides, which
                // must count as "equal" or the main loop would never converge.
                boolean same = (value1 == null) ? (value2 == null) : value1.equals(value2);
                if (!same) {
                    return false;
                }
            }
        } catch (RuntimeException e) {
            e.printStackTrace();
            return false;
        }
        return true;
    }

    /**
     * Builds a new center object whose features are the means of the given elements.
     *
     * @param ps elements of one cluster; may be null or empty
     * @return a new instance with every feature set to the mean over {@code ps};
     *         an untouched new instance if {@code ps} is null or empty; or
     *         {@code null} if the instance cannot be created or a value cannot
     *         be parsed as a number
     */
    public T findNewCenter(List<T> ps) {
        try {
            // Class.newInstance() is deprecated; go through the no-arg constructor instead.
            T t = classT.getDeclaredConstructor().newInstance();
            if (ps == null || ps.isEmpty()) {
                return t;
            }

            double[] ds = new double[fieldNames.size()];
            for (T vo : ps) {
                for (int i = 0; i < fieldNames.size(); i++) {
                    ds[i] += readFeature(vo, fieldNames.get(i));
                }
            }

            for (int i = 0; i < fieldNames.size(); i++) {
                ds[i] = ds[i] / ps.size(); // mean of the feature
                String setName = accessorName("set", fieldNames.get(i));
                System.out.println("ds[i] ++="+ds[i]+"----ps.size()"+ps.size());
                invokeMethod(t, setName, new Class<?>[]{double.class}, ds[i]);
            }
            return t;
        } catch (Exception ex) {
            ex.printStackTrace();
        }
        return null;
    }

    /**
     * Finds the index of the smallest distance.
     *
     * @param dists distances to compare
     * @return index of the first minimum; 0 for an empty array
     */
    public int computOrder(double[] dists) {
        int index = 0;
        for (int i = 1; i < dists.length; i++) {
            // Strict comparison keeps the earliest index on ties.
            if (dists[i] < dists[index]) {
                index = i;
            }
        }
        return index;
    }

    /**
     * Computes the Euclidean distance between two objects over all clustering features.
     *
     * @param p0 first object
     * @param p1 second object
     * @return the square root of the summed squared feature differences; if
     *         reading a feature fails, the features summed up to that point
     *         are used
     */
    public double distance(T p0, T p1) {
        double dis = 0;
        try {
            for (String fieldName : fieldNames) {
                double diff = readFeature(p0, fieldName) - readFeature(p1, fieldName);
                dis += diff * diff;
            }
        } catch (Exception ex) {
            ex.printStackTrace();
        }
        return Math.sqrt(dis);
    }

    /*------ shared helpers -----*/

    /**
     * Looks up a method declared by the owner's class and invokes it.
     *
     * @param owner      object to invoke the method on; must not be null
     * @param methodName name of the method
     * @param argsClass  parameter types of the method, or null for none
     * @param args       arguments passed to the method
     * @return the method's return value, or {@code null} if the method returns
     *         nothing or the lookup/invocation fails (the stack trace is printed)
     */
    public Object invokeMethod(Object owner, String methodName, Class<?>[] argsClass,
            Object... args) {
        Class<?> ownerClass = owner.getClass();
        try {
            Method method = ownerClass.getDeclaredMethod(methodName, argsClass);
            return method.invoke(owner, args);
        } catch (Exception ex) {
            ex.printStackTrace();
        }
        return null;
    }

    /**
     * Builds a bean accessor name such as {@code getGoal} from a prefix and a field name.
     */
    private static String accessorName(String prefix, String fieldName) {
        // Character.toUpperCase is locale-independent, unlike String.toUpperCase().
        return prefix + Character.toUpperCase(fieldName.charAt(0)) + fieldName.substring(1);
    }

    /**
     * Reads one feature through its getter; a null value counts as 0.
     *
     * @throws NumberFormatException if the value's text form is not a number
     */
    private double readFeature(Object owner, String fieldName) {
        Object value = invokeMethod(owner, accessorName("get", fieldName), null);
        return value == null ? 0 : Double.parseDouble(value + "");
    }

}

/**
 * Sample bean used to demonstrate the k-means implementation.
 *
 * <p>Only {@code goal} carries {@link KmeanField}, so it is the only property
 * marked as a clustering feature.
 */
public class Player {

    private int id;

    private String name;

    private int age;

    /** Points scored. */
    @KmeanField
    private double goal;

    /** Assists. */
    private double assists;

    /** Rebounds. */
    private double backboard;

    /** Steals. */
    private double steals;

    /** @return the player id */
    public int getId() {
        return this.id;
    }

    /** @param id the player id */
    public void setId(int id) {
        this.id = id;
    }

    /** @return the player name */
    public String getName() {
        return this.name;
    }

    /** @param name the player name */
    public void setName(String name) {
        this.name = name;
    }

    /** @return the player age */
    public int getAge() {
        return this.age;
    }

    /** @param age the player age */
    public void setAge(int age) {
        this.age = age;
    }

    /** @return points scored */
    public double getGoal() {
        return this.goal;
    }

    /** @param goal points scored */
    public void setGoal(double goal) {
        this.goal = goal;
    }

    /** @return assists */
    public double getAssists() {
        return this.assists;
    }

    /** @param assists assists */
    public void setAssists(double assists) {
        this.assists = assists;
    }

    /** @return rebounds */
    public double getBackboard() {
        return this.backboard;
    }

    /** @param backboard rebounds */
    public void setBackboard(double backboard) {
        this.backboard = backboard;
    }

    /** @return steals */
    public double getSteals() {
        return this.steals;
    }

    /** @param steals steals */
    public void setSteals(double steals) {
        this.steals = steals;
    }

    /**
     * @return the player name
     */
    @Override
    public String toString() {
        return this.name;
    }
}

 

 
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Demo entry point: clusters nine sample players into two groups by their
 * {@code goal} value and prints each group.
 */
public class TestMain {

    /**
     * Builds the sample data, runs the clustering and prints every cluster.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // Goal values of the sample players "afei1" .. "afei9".
        double[] goals = {1, 2, 3, 7, 8, 25, 26, 27, 28};

        List<Player> listPlayers = new ArrayList<Player>();
        for (int i = 0; i < goals.length; i++) {
            Player player = new Player();
            player.setName("afei" + (i + 1));
            player.setGoal(goals[i]);
            listPlayers.add(player);
        }
        // The first sample player is the only one with a non-default assists value.
        listPlayers.get(0).setAssists(8);

        Kmeans<Player> kmeans = new Kmeans<Player>(listPlayers, 2);
        List<Player>[] results = kmeans.comput();
        for (int i = 0; i < results.length; i++) {
            System.out.println("===========类别" + (i + 1) + "================");
            for (Player p : results[i]) {
                System.out.println(p.getName() + "--->"
                        + p.getGoal() + "," + p.getAssists() + ","
                        + p.getSteals() + "," + p.getBackboard());
            }
        }
    }

}

 

源码:https://github.com/chaoren399/dkdemo/tree/master/kmeans/src

你可能感兴趣的:(3.聚类–K-means的Java实现)