聚类算法:K-Means 的 Java 实现

下图为 k-means 算法指定 2 个聚类中心(k = 2)、迭代 10 次后得到的聚类结果:

[图 1:k-means 聚类结果示意图]

package com.yc;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;

public class KMeans {
    /** Input file: one point per line, coordinates separated by ';', ',', '|' or whitespace. */
    private static final String fileName = "points.txt";
    /** Output file: each point's coordinates followed by its 1-based cluster number. */
    private static final String resultFileName = "result.txt";
    /** Number of clusters. */
    private static final int k = 2;
    /** Number of center-update iterations performed after the initial assignment. */
    private static final int n = 10;
    /** Current cluster centers; center[i] is the center of cluster i. */
    private static Location[] center = new Location[k];
    /** Current assignment: microCluster.get(i) holds the points closest to center[i]. */
    private static List<List<Point>> microCluster = new CopyOnWriteArrayList<>();

    /** A coordinate vector used as a cluster center. */
    static class Location {
        private List<Double> location; // coordinates of the center

        public Location(List<Double> location) {
            this.location = location;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result
                    + ((location == null) ? 0 : location.hashCode());
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj)
                return true;
            if (obj == null)
                return false;
            if (getClass() != obj.getClass())
                return false;
            Location other = (Location) obj;
            if (location == null) {
                if (other.location != null)
                    return false;
            } else if (!location.equals(other.location))
                return false;
            return true;
        }

        @Override
        public String toString() {
            return "Location [location=" + location + "]";
        }

    }

    /** An input data point: its sequence number in the file, its coordinates and a status flag. */
    static class Point {
        private int id; // sequence number of the point (0-based line order in the input file)
        private List<Double> location; // coordinates of the point
        private int flag; // 0 = not processed, 1 = processed (not read by this class's algorithm)

        Point(int id, List<Double> location, int flag) {
            this.id = id;
            this.location = location;
            this.flag = flag;
        }

        public int getId() {
            return id;
        }

        public void setId(int id) {
            this.id = id;
        }

        public List<Double> getLocation() {
            return location;
        }

        public void setLocation(List<Double> location) {
            this.location = location;
        }

        public int getFlag() {
            return flag;
        }

        public void setFlag(int flag) {
            this.flag = flag;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result + flag;
            result = prime * result + id;
            result = prime * result
                    + ((location == null) ? 0 : location.hashCode());
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj)
                return true;
            if (obj == null)
                return false;
            if (getClass() != obj.getClass())
                return false;
            Point other = (Point) obj;
            if (flag != other.flag)
                return false;
            if (id != other.id)
                return false;
            if (location == null) {
                if (other.location != null)
                    return false;
            } else if (!location.equals(other.location))
                return false;
            return true;
        }

        @Override
        public String toString() {
            return "Point [id=" + id + ", location=" + location + ", flag="
                    + flag + "]";
        }
    }

    /**
     * Reads the input points from {@link #fileName}.
     *
     * <p>Each non-blank line is one point. Coordinates may be separated by ';', ',', '|' or
     * whitespace; a run of consecutive separators counts as one. Points are numbered from 0 in
     * file order and start with flag 0 (not processed).
     *
     * @return the points in file order
     * @throws Exception if the file cannot be read or a coordinate is not a valid number
     */
    public static List<Point> readFile() throws Exception {
        List<Point> points = new ArrayList<>();
        // try-with-resources closes the reader even when a line fails to parse.
        try (BufferedReader file = new BufferedReader(new FileReader(fileName))) {
            int id = 0;
            String line;
            while ((line = file.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue; // e.g. a trailing empty line; it would otherwise fail parseDouble("")
                }
                String[] parts = trimmed.split("[;,|\\s]+");
                List<Double> location = new ArrayList<>(parts.length);
                for (String part : parts) {
                    location.add(Double.parseDouble(part));
                }
                points.add(new Point(id++, location, 0));
            }
        }
        return points;
    }

    /**
     * Writes the current clustering to {@link #resultFileName}.
     *
     * <p>One line per point: its coordinates, each followed by ',', then the 1-based number of
     * the cluster it belongs to.
     *
     * @throws Exception if the file cannot be written
     */
    public static void saveResult() throws Exception {
        // try-with-resources flushes and closes the writer even if a write fails.
        try (BufferedWriter bw = new BufferedWriter(new FileWriter(resultFileName))) {
            int clusterNo = 1;
            for (List<Point> mic : microCluster) {
                for (Point p : mic) {
                    StringBuilder sb = new StringBuilder();
                    for (Double coordinate : p.location) {
                        sb.append(coordinate).append(',');
                    }
                    bw.write(sb.append(clusterNo).toString());
                    bw.newLine();
                }
                clusterNo++;
            }
        }
    }

    /**
     * Adds uniformly random integers from the half-open range [min, max) to {@code set} until it
     * holds {@code n} elements.
     *
     * <p>When the range contains fewer than {@code n} distinct values (or {@code max < min}),
     * the method returns without touching {@code set}, since the loop could never finish.
     *
     * @param min inclusive lower bound
     * @param max exclusive upper bound
     * @param n   target number of elements in {@code set}
     * @param set receives the random values
     */
    public static void randomSet(int min, int max, int n, Set<Integer> set) {
        // Math.random() < 1.0, so max itself is never drawn: only (max - min) values exist.
        if (max < min || n > max - min) {
            return;
        }
        while (set.size() < n) {
            set.add((int) (Math.random() * (max - min)) + min); // duplicates are absorbed by the set
        }
    }

    /**
     * Computes the centroid (coordinate-wise mean) of {@code points}.
     *
     * @param points a non-empty list of points; all points are expected to have as many
     *               coordinates as the first one
     * @return the centroid, with exactly as many coordinates as the first point
     * @throws IllegalArgumentException if {@code points} is empty (the mean is undefined)
     */
    public static Location getNewCenter(List<Point> points) {
        if (points.isEmpty()) {
            throw new IllegalArgumentException("cannot compute the center of an empty cluster");
        }
        int dims = points.get(0).location.size();
        double[] sum = new double[dims];
        for (Point p : points) {
            for (int i = 0; i < dims; i++) {
                sum[i] += p.location.get(i);
            }
        }
        List<Double> loc = new ArrayList<>(dims);
        for (double s : sum) {
            loc.add(s / points.size());
        }
        return new Location(loc);
    }

    /**
     * Returns the Euclidean distance between a point and a center, over the point's dimensions.
     *
     * @throws IllegalArgumentException if the center has fewer coordinates than the point
     */
    public static double getDistance(Point point1, Location loc) {
        int wide = point1.location.size(); // number of dimensions
        if (loc.location.size() < wide) {
            // Previously swallowed, which silently produced too-small distances.
            throw new IllegalArgumentException("center has " + loc.location.size()
                    + " coordinates but point " + point1.id + " has " + wide);
        }
        double sum = 0;
        for (int i = 0; i < wide; i++) {
            double d = point1.location.get(i) - loc.location.get(i);
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /**
     * Rebuilds {@link #microCluster} by assigning every point to its nearest center.
     *
     * <p>Ties go to the lowest-numbered center. Requires every entry of {@link #center} to be
     * set.
     */
    public static void micCluster(List<Point> points) {
        microCluster = new ArrayList<>();
        for (int i = 0; i < k; i++) {
            microCluster.add(new ArrayList<>());
        }
        for (Point p : points) {
            List<Double> dis = new ArrayList<>(k);
            for (int i = 0; i < k; i++) {
                dis.add(getDistance(p, center[i]));
            }
            microCluster.get(getMin(dis)).add(p);
        }
    }

    /**
     * Runs k-means. It makes an initial assignment and then {@link #n} rounds of "move each
     * center to its cluster's mean, reassign points".
     *
     * <p>Requires {@link #center} to hold the initial centers.
     */
    public static void cluster(List<Point> points) {
        micCluster(points);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < k; j++) {
                List<Point> members = microCluster.get(j);
                // An empty cluster has no mean: keep its previous center. An empty center would
                // otherwise look like distance 0 to every point.
                if (!members.isEmpty()) {
                    center[j] = getNewCenter(members); // new centroid
                }
            }
            micCluster(points);
        }

    }

    /**
     * Returns the index of the smallest value in {@code dis}; on ties, the first such index.
     *
     * @throws IndexOutOfBoundsException if {@code dis} is empty
     */
    public static int getMin(List<Double> dis) {
        int loc = 0;
        double min = dis.get(0);
        for (int i = 1; i < dis.size(); i++) {
            if (dis.get(i) < min) {
                min = dis.get(i);
                loc = i;
            }
        }
        return loc;
    }

    /**
     * Entry point. It reads the points, seeds {@link #k} centers from distinct random points,
     * clusters, writes the result, and prints the elapsed time in nanoseconds.
     */
    public static void main(String[] args) throws Exception {
        long start = System.nanoTime();
        List<Point> points = readFile();
        if (points.size() < k) {
            throw new IllegalStateException("need at least " + k + " points, but "
                    + fileName + " contains " + points.size());
        }
        Set<Integer> centers = new HashSet<>();
        randomSet(0, points.size(), k, centers); // k distinct indices into points
        List<Integer> seeds = new ArrayList<>(centers);
        for (int i = 0; i < k; i++) {
            center[i] = new Location(points.get(seeds.get(i)).getLocation());
        }
        cluster(points);
        saveResult();
        System.out.println("use time:" + (System.nanoTime() - start)); // nanoseconds
    }

}

相关标签:java、数据挖掘