对词向量进行Kmeans聚类

用过word2vec的人都知道,word2vec训练得到的结果是每个词对应一个向量。虽然word2vec自带了kmeans聚类功能,但它是对词表中的所有词进行聚类;如果只需要对其中一部分词按照向量进行kmeans聚类,就只能自己实现了。

参考网上一个开源的JAVA版word2vec,可以实现一个JAVA版的kmeans聚类。输入为一个CSV文件,每一行是一个词及其向量(以逗号分隔);输出的每一行依次为该词所属的类别编号、词本身,以及该词到所属类中心的距离(欧氏距离的平方,向量在读入时已归一化为单位长度),各字段以制表符分隔。

代码如下:

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.HashMap;



/**
 * Loads word vectors from a CSV file into memory.
 *
 * <p>Each non-blank line of the file is expected to look like {@code word,v1,v2,...,vn}. Every
 * vector is scaled to unit L2 length as it is loaded. For unit vectors the squared Euclidean
 * distance equals {@code 2 - 2 * cosine}, so distance-based clustering on them orders words the
 * same way cosine similarity does.
 */
public class Word2VEC {

    /** Word to its (unit-length, unless all-zero) vector. */
    private HashMap<String, float[]> wordMap = new HashMap<>();

    /**
     * Reads {@code path} as UTF-8 and adds one entry per non-blank line to the word map.
     *
     * <p>A later line with the same word replaces the earlier one. An all-zero vector is stored
     * unchanged, because it has no direction to normalize (dividing by a zero norm would fill it
     * with NaN).
     *
     * @param path path of the CSV file
     * @throws IOException if the file cannot be opened or read
     * @throws NumberFormatException if a vector component is not a valid float
     */
    public void loadVectorFile(String path) throws IOException {
        int size = 0;
        // try-with-resources: the reader is closed even on failure, and a failed open surfaces
        // as the real IOException instead of a NullPointerException from closing a null reader.
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(new File(path)), "UTF-8"))) {
            String line;
            while ((line = br.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // blank line (e.g. a trailing newline) carries no word
                }
                String[] fields = line.split(",");
                String word = fields[0];
                size = fields.length - 1;
                float[] vector = new float[size];
                double squaredNorm = 0;
                for (int j = 0; j < size; j++) {
                    vector[j] = Float.parseFloat(fields[j + 1]);
                    squaredNorm += (double) vector[j] * vector[j];
                }
                double norm = Math.sqrt(squaredNorm);
                if (norm > 0) {
                    for (int j = 0; j < size; j++) {
                        vector[j] /= norm;
                    }
                }
                wordMap.put(word, vector);
            }
        }
        // "vector dimensions" reports the dimension of the last non-blank line read.
        System.out.println("total word: " + wordMap.size() + " vector dimensions: " + size);
    }

    /**
     * Returns the live word map (not a copy).
     *
     * @return map from word to its loaded vector
     */
    public HashMap<String, float[]> getWordMap() {
        return wordMap;
    }
}

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

public class WordKmeans {
    private HashMap wordMap = null;
    private int iter;
    private Classes[] cArray = null;
    
	//total 659624 words each is a 200 vector
	//args[0] is the word vectors csv file
	//args[1] is the output file 
	//args[2] is the cluster number
	//args[3] is the iterator number
    public static void main(String[] args) throws IOException {
        Word2VEC vec = new Word2VEC();
        vec.loadVectorFile(args[0]);
        System.out.println("load data ok!");
        
        //input cluster number and iterator number
        WordKmeans wordKmeans = new WordKmeans(vec.getWordMap(), Integer.parseInt(args[2]),Integer.parseInt(args[3]));
        Classes[] explain = wordKmeans.explain();

        File fw = new File(args[1]);
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(fw), "UTF-8"));
	    
	    //explain.length is the classes number
        for (int i = 0; i < explain.length; i++) {
            List> result=explain[i].getMember();
            StringBuffer buf = new StringBuffer();
            for (int j = 0; j < result.size(); j++) {
            	buf.append(i+"\t"+result.get(j).getKey()+"\t"+result.get(j).getValue().toString()+"\n");
            }
            bw.write(buf.toString());
            bw.flush();
        }
        bw.close();
    }


    public WordKmeans(HashMap wordMap, int clcn, int iter) {
        this.wordMap = wordMap;
        this.iter = iter;
        cArray = new Classes[clcn];
    }

    public Classes[] explain() {
        Iterator> iterator = wordMap.entrySet().iterator();
        for (int i = 0; i < cArray.length; i++) {
            Entry next = iterator.next();
            cArray[i] = new Classes(i, next.getValue());
        }
        for (int i = 0; i < iter; i++) {
            for (Classes classes : cArray) {
                classes.clean();
            }
            iterator = wordMap.entrySet().iterator();
            int cnt = 0;
            while (iterator.hasNext()) {
            	if(cnt % 10000 ==0)
            	{
            		System.out.println("Iter:"+i+"\tword:"+(cnt));
            	}
            	cnt++;
                Entry next = iterator.next();
                double miniScore = Double.MAX_VALUE;
                double tempScore;
                int classesId = 0;
                for (Classes classes : cArray) {
                    tempScore = classes.distance(next.getValue());
                    if (miniScore > tempScore) {
                        miniScore = tempScore;
                        classesId = classes.id;
                    }
                }
                cArray[classesId].putValue(next.getKey(), miniScore);
            }
            for (Classes classes : cArray) {
                classes.updateCenter(wordMap);
            }
            System.out.println("iter " + i + " ok!");
        }
        return cArray;
    }

    public static class Classes {
        private int id;
        private float[] center;
        public Classes(int id, float[] center) {
            this.id = id;
            this.center = center.clone();
        }

        Map values = new HashMap<>();
        public double distance(float[] value) {
            double sum = 0;
            for (int i = 0; i < value.length; i++) {
                sum += (center[i] - value[i])*(center[i] - value[i]) ;
            }
            return sum ;
        }

        public void putValue(String word, double score) {
            values.put(word, score);
        }

        public void updateCenter(HashMap wordMap) {
            for (int i = 0; i < center.length; i++) {
                center[i] = 0;
            }
            float[] value = null;
            for (String keyWord : values.keySet()) {
                value = wordMap.get(keyWord);
                for (int i = 0; i < value.length; i++) {
                    center[i] += value[i];
                }
            }
            for (int i = 0; i < center.length; i++) {
                center[i] = center[i] / values.size();
            }
        }

        public void clean() {
            values.clear();
        }

        public List> getTop(int n) {
            List> arrayList = new ArrayList>(
                values.entrySet());
            Collections.sort(arrayList, new Comparator>() {
                @Override
                public int compare(Entry o1, Entry o2) {
                    return o1.getValue() > o2.getValue() ? 1 : -1;
                }
            });
            int min = Math.min(n, arrayList.size() - 1);
            if(min<=1){
            	return Collections.emptyList() ;
            }
            return arrayList.subList(0, min);
        }
        
        public List> getMember() {
            List> arrayList = new ArrayList>(
                values.entrySet());
            Collections.sort(arrayList, new Comparator>() {
                @Override
                public int compare(Entry o1, Entry o2) {
                    return o1.getValue() > o2.getValue() ? 1 : -1;
                }
            });
            int count=arrayList.size() - 1;
            if(count<=1){
            	return Collections.emptyList() ;
            }
            return arrayList.subList(0, count);
        }
    }
}

进行聚类时需要依次指定四个参数:输入文件、输出文件、类别数目和迭代次数。经过试验,JAVA版本的速度不是很快;C版本的速度尚未测试,估计会比JAVA快不少。

你可能感兴趣的:(Felven在职场)