FuzzyKmeans聚类JAVA版本实现

在对数据进行聚类时,最常用的方法应该是kmeans,但是kmeans只能保证每一条待聚类的数据被划分到一个类别,无法处理一条数据可以被划分到多个类别的情况。为此,人们提出了FuzzyKmeans聚类方法,该方法衡量的是每一条数据属于某个类别的概率,既然是概率就不再是非1即0的情况,这样就能保证一条数据可以被划分到多个类别。

对应FuzzyKmeans的聚类过程如下:

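FuzzyKmeans聚类JAVA版本实现_第1张图片

(原文此处为算法流程图片,图片内容未能保留。以下为标准FuzzyKmeans的核心公式,仅供参考,编号可能与原图不一致。)

隶属度(概率)更新公式,其中 $d_{ij} = \lVert x_i - c_j \rVert^2$,$C$ 为类别数:

$$u_{ij} = \frac{1}{\sum_{k=1}^{C} \left( \dfrac{d_{ij}}{d_{ik}} \right)^{\frac{1}{m-1}}}$$

中心点更新公式,其中 $N$ 为数据条数:

$$c_j = \frac{\sum_{i=1}^{N} u_{ij}^{m} \, x_i}{\sum_{i=1}^{N} u_{ij}^{m}}$$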


其中dij这个参数衡量的是该条数据i到类别j中心点的距离,uij就是数据i属于类别j的概率。

求得概率之后,需要更新每个类别的中心点,这时就按照(4)式更新,也就是以各条数据属于该类的概率为权重,对数据原先的值做加权平均。

至于结束条件,一种是达到设定的迭代次数,另一种是满足第四步的条件,即同一类别在相邻两次迭代中的中心点距离小于设定的阈值。

最重要的应该是m值的选择。当每条数据到各个类别中心点的距离比较接近时,建议1/(m-1)取值较大,因为这样在指数运算后距离就能有较大差异了,此时m接近于1。如果距离本来就有很大差异,1/(m-1)就可以取值小一些。一般来说m取1.5就足够了。

最后要注意迭代次数不宜过多,一般两次足够,因为考虑的是概率,如果迭代次数过多,中心点偏移较大,很可能得到数据到各个类别的概率都相差不大。


下面是用JAVA实现的FuzzyKmeans,每一条数据都是一个200维的向量。使用时可以指定初始中心点,中心点的向量需要从待聚类数据中查找得到。

首先是处理输入的类:

package kmeans;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;



/**
 * Loads word vectors from a CSV file and exposes them as a word-to-vector map.
 *
 * <p>Every input line is expected to look like {@code word,v1,v2,...,vN}. Each vector is
 * scaled to unit Euclidean length before it is stored, so the map holds normalised vectors.
 */
public class Word2VEC {

    /** Word -> normalised vector; filled by {@link #loadVectorFile(String)}. */
    private final HashMap<String, double[]> wordMap = new HashMap<>();

    /**
     * Reads a UTF-8 CSV file of word vectors and adds every entry to the word map.
     *
     * <p>Blank lines and lines that contain nothing but separators are skipped. A vector whose
     * length is zero is stored unchanged instead of being divided by zero. A progress line is
     * printed every 100000 input lines and a summary line is printed when loading ends,
     * whether it succeeded or failed.
     *
     * @param path path of the CSV file to read
     * @throws IOException if the file cannot be opened or read
     * @throws NumberFormatException if a vector component is not a valid number
     */
    public void loadVectorFile(String path) throws IOException {
        int size = 0;
        // try-with-resources: the reader is closed on every path, and a failure to open the
        // file surfaces as the original IOException rather than an NPE from closing null.
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(new File(path)), "UTF-8"))) {
            int count = 0;
            String line;
            while ((line = br.readLine()) != null) {
                if (count % 100000 == 0) {
                    System.out.println("read: " + count);
                }
                count++;

                if (line.trim().isEmpty()) {
                    continue; // a blank line carries no word and no vector
                }
                String[] outline = line.split(",");
                if (outline.length == 0) {
                    continue; // the line consisted of separators only
                }

                size = outline.length - 1;
                double[] vectors = new double[size];
                double len = 0;
                for (int j = 0; j < size; j++) {
                    // Parse at full double precision; the values are stored as double.
                    double component = Double.parseDouble(outline[j + 1]);
                    vectors[j] = component;
                    len += component * component;
                }

                len = Math.sqrt(len);
                if (len > 0) {
                    // Scale to unit length. Skipped for zero vectors to avoid NaN components.
                    for (int j = 0; j < size; j++) {
                        vectors[j] /= len;
                    }
                }
                wordMap.put(outline[0], vectors);
            }
        } finally {
            System.out.println("total word: " + wordMap.size() + " vector dimensions: " + size);
        }
    }

    /**
     * Returns the live map of loaded words to their normalised vectors.
     *
     * @return the internal word map (not a copy)
     */
    public HashMap<String, double[]> getWordMap() {
        return wordMap;
    }

    /**
     * Reads candidate center words, one per line, from a UTF-8 file and keeps those that are
     * present in the loaded word map.
     *
     * @param point_path path of the file listing one candidate word per line
     * @return the known words, in the order they appear in the file
     * @throws IOException if the file cannot be opened or read
     */
    public List<String> loadPointFile(String point_path) throws IOException {
        List<String> center = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(new File(point_path)), "UTF-8"))) {
            String line;
            while ((line = br.readLine()) != null) {
                if (wordMap.containsKey(line)) {
                    center.add(line);
                }
            }
        }
        return center;
    }
}

然后是聚类的类:

package kmeans;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;






import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.cli.PosixParser;

public class FuzzyKmeans {
	 	private HashMap wordMap = null;
	    private int iter;
	    private Classes[] cArray = null;
	    public static HashMap wordcenter=new HashMap();
	    
		//total 659624 words each is a 200 vector
		//args[0] is the word vectors csv file
		//args[1] is the output file 
		//args[2] is the cluster number
		//args[3] is the iterator number
	    public static void main(String[] args) throws IOException, ParseException {
	    	
	    	String source_path;
	    	String output_path;
	    	int cluster_num = 10;
	    	int iterator_num = 10;
	    	double m=1.5;
	
	    	
	    	String point_path = null;
	    	
	    	 Options options = new Options();  
	         options.addOption("h", false, "help"); //参数不可用
	         options.addOption("i", true, "input file path"); //参数可用     
	         options.addOption("o", true, "output file path"); //参数可用 
	         options.addOption("c", true, "cluster number, default 10"); //参数可用 
	         options.addOption("x", true, "iterator number, default 10"); //参数可用 
	         options.addOption("p", true, "the center point"); //参数可用
	         options.addOption("m", true, "the parameter for fuzzy kmeans"); //参数可用
	         
	         CommandLineParser parser = new PosixParser();  
	         CommandLine cmd = parser.parse(options, args);  
	   
	         if (cmd.hasOption("i"))  
	         {  
	         	source_path = cmd.getOptionValue("i");  
	         }else{
	        	 HelpFormatter formatter = new HelpFormatter();  
	             formatter.printHelp( "help", options ); 
	             return;
	         }
	         
	         if (cmd.hasOption("o"))  
	         {  
	        	 output_path = cmd.getOptionValue("o");  
	         }else{
	        	 HelpFormatter formatter = new HelpFormatter();  
	             formatter.printHelp( "help", options ); 
	             return;
	         }
	   
	         if (cmd.hasOption("c"))  
	         {  
	        	 cluster_num = Integer.parseInt(cmd.getOptionValue("c"));  
	         }
	         if (cmd.hasOption("m"))  
	         {  
	        	 m = Double.parseDouble(cmd.getOptionValue("m"));  
	         }
	         
	         if (cmd.hasOption("x"))  
	         {  
	        	 iterator_num = Integer.parseInt(cmd.getOptionValue("x"));  
	         }
	         if (cmd.hasOption("p"))  
	         {  
	        	 point_path = cmd.getOptionValue("p");  
	         }
	         
	         if (cmd.hasOption("h"))  
	         {  
	             HelpFormatter formatter = new HelpFormatter();  
	             formatter.printHelp( "help", options ); 
	         }
	         
	        Word2VEC vec = new Word2VEC();
	        vec.loadVectorFile(source_path);
	        System.out.println("load data ok!");
	        
	        
	        List center=new ArrayList();
	        if(point_path!=null){
	        	center=vec.loadPointFile(point_path);
	        	if(cluster_num> result=explain[i].getMember();
	            StringBuffer buf = new StringBuffer();
	            for (int j = 0; j < result.size(); j++) {
	            	buf.append(i+"\t"+wordcenter.get(i)+"\t"+result.get(j).getKey()+"\t"+String.format("%.6f", result.get(j).getValue())+"\n");
	            }
	            bw.write(buf.toString());
	            bw.flush();
	        }
	        bw.close();
	        
	        for(int i=0;i wordMap, int clcn, int iter) {
	        this.wordMap = wordMap;
	        this.iter = iter;
	        cArray = new Classes[clcn];
	    }

	    public Classes[] explain(String point_path,double m,List center) throws IOException, FileNotFoundException {
	    	Iterator> iterator = wordMap.entrySet().iterator();
	    	//cluster number is the same as the center point number
	    	if(cArray.length==center.size()){
	    		String word="";
	    		for (int i = 0; i < cArray.length; i++) {
	    		    word=center.get(i);
	    		    cArray[i] = new Classes(i, wordMap.get(word));
	    		    wordcenter.put(i, word);
	    		    System.out.println(new String(word.getBytes("UTF-8")));
	 	        }
	    	}
	    	
	    	else{
		    	if(point_path==null){
			        for (int i = 0; i < cArray.length; i++) {
			            Entry next = iterator.next();
			            cArray[i] = new Classes(i, next.getValue());
			        }
		    	}
		    	else{
		    		String word="";
		    		File f = new File(point_path);
		    		BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(f), "UTF-8"));
		    		for (int i = 0; i < cArray.length; i++) {
		    		    word=br.readLine();
		    		    if(wordMap.containsKey(word)){
		    		    	cArray[i] = new Classes(i, wordMap.get(word));
		    		    	wordcenter.put(i, word);
		    		    	System.out.println(new String(word.getBytes("UTF-8")));
		    		    }
		    		    else{
		    		    	Entry next = iterator.next();
		    		    	cArray[i] = new Classes(i, next.getValue());
		    		    	wordcenter.put(i, next.getKey());
		    		    }
		 	        }
		    		br.close();
		    	}
	    	}
	    	
	    	
	    	
	    	iterator = wordMap.entrySet().iterator();
	    	HashMap num_wordmap=new HashMap();
            HashMap num_vecmap=new HashMap();
            //put word to the map
            int count=0;
            while (iterator.hasNext()) {
            	Entry next = iterator.next();
            	num_wordmap.put(count, next.getKey());
            	num_vecmap.put(count, next.getValue());
            	count++;
            }
	    	
	    	
	    	//begin iterator step
	        for (int i = 0; i < iter; i++) {	
	            for (Classes classes : cArray) {
	                classes.clean();
	            }
	            
	            double u[][]=new double[cArray.length][count];
	            
	            int cnt = 0;
	            int num=0;
	            
	            while (num values = new HashMap();
	        /**
	         * Measures how far {@code value} lies from the {@code center} array.
	         *
	         * <p>The result is the sum of squared per-component differences; no square root is
	         * taken, so this is the squared Euclidean distance. Only the first
	         * {@code value.length} components of {@code center} are read.
	         *
	         * @param value the point to compare against the center
	         * @return the sum of squared component differences
	         */
	        public double distance(double[] value) {
	            double total = 0;
	            int dim = 0;
	            while (dim < value.length) {
	                double delta = center[dim] - value[dim];
	                total += delta * delta;
	                dim++;
	            }
	            return total;
	        }
	        
	        /**
	         * Records {@code score} for {@code word} in the {@code values} map; a score already
	         * stored for the same word is overwritten.
	         *
	         * @param word  the word being scored
	         * @param score the score (described by the original author as a probability)
	         */
	        public void putValue(String word, double score) {
	            final Double boxedScore = Double.valueOf(score);
	            values.put(word, boxedScore);
	        }

	        public void updateCenter(HashMap num_vecmap,double[][] u,int count,double m) {
	            for (int i = 0; i < center.length; i++) {
	                center[i] = 0;
	            }
	            double[] value = null;
	            
	            for(int j=0;j> getMember() {
	            List> arrayList = new ArrayList>(
	                values.entrySet());
	            int count=arrayList.size();
	            if(count<=0){
	            	return Collections.emptyList() ;
	          	}
	            return arrayList;
	        }
	    }
}


你可能感兴趣的:(Felven在职场)