聚类算法 dbscan java 实现

最近总结了一下聚类算法,顺手实现了一下dbscan

DBSCAN 的思想
   给定半径 r 和密度阈值 minPoints,以某个点为圆心、r 为半径画圈;
若圈内的点数(包含该点自身)不少于密度阈值,则该点就是核心点,核心点
和圈内的点形成一个微簇。依次对每个未访问过的点画圈,并把新得到的微簇
与已有的、存在公共点的微簇合并,直到所有的点都被访问过。

聚类算法 dbscan java 实现_第1张图片

package com.yc;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;

public class DBSCAN {
    private static final String fileName = "points.txt";
    private static final String resultFileName = "result.txt";
    private static final double r = 0.05; // 半径
    private static final int minPoints = 5;// 密度阈值
    private static List> microCluster = new CopyOnWriteArrayList<>();

    static class Point {
        private int id; // 点的序列号
        private List location;// 点的坐标
        private int flag; // 0表示未处理,1表示处理了

        Point(int id, List location, int flag) {
            this.id = id;
            this.location = location;
            this.flag = flag;
        }

        public int getId() {
            return id;
        }

        public void setId(int id) {
            this.id = id;
        }

        public List getLocation() {
            return location;
        }

        public void setLocation(List location) {
            this.location = location;
        }

        public int getFlag() {
            return flag;
        }

        public void setFlag(int flag) {
            this.flag = flag;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result + flag;
            result = prime * result + id;
            result = prime * result
                    + ((location == null) ? 0 : location.hashCode());
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj)
                return true;
            if (obj == null)
                return false;
            if (getClass() != obj.getClass())
                return false;
            Point other = (Point) obj;
            if (flag != other.flag)
                return false;
            if (id != other.id)
                return false;
            if (location == null) {
                if (other.location != null)
                    return false;
            } else if (!location.equals(other.location))
                return false;
            return true;
        }

        @Override
        public String toString() {
            return "Point [id=" + id + ", location=" + location + ", flag="
                    + flag + "]";
        }
    }

    public static List readFile() throws Exception {
        BufferedReader file = new BufferedReader(new FileReader(fileName));
        List points = new ArrayList();
        String line = file.readLine();
        int i = 0;
        while (line != null) {
            String p[] = line.split("[;|,|\\s]");
            List location = new ArrayList();
            for (int j = 0; j < p.length; j++) {
                location.add(Double.parseDouble(p[j]));
            }
            points.add(new Point(i++, location, 0));
            line = file.readLine();
        }
        file.close();
        return points;
    }

    public static void saveResult() throws Exception {
        BufferedWriter bw = new BufferedWriter(new FileWriter(resultFileName));
        int i = 1;
        for(List mic : microCluster){
            for(Point p  : mic){
                StringBuffer sb = new StringBuffer();
                for(int j = 0;j",");
                }
                bw.write(sb.append(i).toString());
                bw.newLine();
            }
            i++;
        }
        bw.flush();
        bw.close();
    }

    public static double getDistance(Point point1, Point point2) {
        int wide = point1.location.size(); // 共多少维
        double sum = 0;
        for (int i = 0; i < wide; i++) {
            sum += Math.pow((point1.location.get(i) - point2.location.get(i)), 2);
        }
        return Math.sqrt(sum);
    }

    public static List getDistances(Point point1, List points) {
        List diss = new ArrayList<>();
        int size = points.size();
        for (int i = 0; i < size; i++)
            diss.add(getDistance(point1, points.get(i)));
        return diss;
    }

    public static boolean canCombine(List clu1, List clu2) {
        Set result = new HashSet();
        Set s1 = new HashSet(clu1);
        Set s2 = new HashSet(clu2);
        result.clear();
        result.addAll(s1);
        result.retainAll(s2);
        if (result.size() > 0) {
            return true;
        }
        return false;
    }

    public static void combine(List clu){
        List combine = new ArrayList<>();
        for(int i = 0;i< microCluster.size();i++){
            if(canCombine(clu,microCluster.get(i))) combine.add(i);
        }
        Set com = new HashSet<>();
        List> remove = new ArrayList<>();
        for(int i = 0;i p = microCluster.get(combine.get(i));
            com.addAll(p);
            remove.add(p);
        }
        microCluster.removeAll(remove);
        microCluster.add(new ArrayList<>(com));
    }

    public static void cluster(List points) {
        int size = points.size();
        for (int i = 0; i < size; i++) {
            Point p = points.get(i);
            if (p.flag != 0) continue;
            p.flag = 1;
            List diss = getDistances(p, points);
            List clu = new ArrayList();
            for (int j = 0; j < size; j++) {
                if (diss.get(j) < r)   clu.add(points.get(j)); //若距离大于r 
            }
            if(clu.size() < minPoints) continue;
            else {
                microCluster.add(clu);
                combine(clu);
            }
        }
    }

    public static void main(String[] args) throws Exception {
        long start = System.nanoTime();
        List points = readFile();
        cluster(points);
        saveResult();
        System.out.println("use time:"+(System.nanoTime() - start));
    }
}

你可能感兴趣的:(java,数据挖掘)