使用过word2vec的人都知道,使用word2vec训练得到的结果是每个词对应一个向量。虽然word2vec提供了kmeans的聚类方法,但是它是对所有原始的词进行聚类,如果我们只需要对其中一部分词按照向量进行kmeans聚类,那只好自己写方法。
参考网上一个开源的JAVA版 word2vec,可以得到JAVA版的kmeans聚类:输入为一个csv文件,每一行为词和其向量;输出文件的每一行为词的类别编号、词,以及该词到所属聚类中心的距离。
代码如下:
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.HashMap;
public class Word2VEC {

    /** Maps each word to its L2-normalized embedding vector. */
    private HashMap<String, float[]> wordMap = new HashMap<String, float[]>();

    /**
     * Loads word vectors from a CSV file where each line is
     * {@code word,v1,v2,...,vN}, and stores each vector L2-normalized
     * in {@link #wordMap}.
     *
     * @param path path to the UTF-8 encoded CSV file
     * @throws IOException if the file cannot be opened or read
     */
    public void loadVectorFile(String path) throws IOException {
        int size = 0;
        // try-with-resources guarantees the reader is closed even if reading
        // fails; the original closed it in finally and could NPE when the
        // stream constructor itself threw.
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(new File(path)), "UTF-8"))) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] outline = line.split(",");
                size = outline.length - 1;
                if (size <= 0) {
                    continue; // skip malformed rows that carry no vector components
                }
                String word = outline[0];
                float[] vectors = new float[size];
                double len = 0;
                for (int j = 0; j < size; j++) {
                    float v = Float.parseFloat(outline[j + 1]);
                    len += (double) v * v;
                    vectors[j] = v;
                }
                len = Math.sqrt(len);
                // Guard against an all-zero vector: dividing by len == 0
                // would fill the vector with NaN.
                if (len > 0) {
                    for (int j = 0; j < size; j++) {
                        vectors[j] /= len;
                    }
                }
                wordMap.put(word, vectors);
            }
        }
        System.out.println("total word: " + wordMap.size() + " vector dimensions: " + size);
    }

    /** @return the word-to-normalized-vector map filled by {@link #loadVectorFile}. */
    public HashMap<String, float[]> getWordMap() {
        return wordMap;
    }
}
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
public class WordKmeans {
private HashMap wordMap = null;
private int iter;
private Classes[] cArray = null;
//total 659624 words each is a 200 vector
//args[0] is the word vectors csv file
//args[1] is the output file
//args[2] is the cluster number
//args[3] is the iterator number
public static void main(String[] args) throws IOException {
Word2VEC vec = new Word2VEC();
vec.loadVectorFile(args[0]);
System.out.println("load data ok!");
//input cluster number and iterator number
WordKmeans wordKmeans = new WordKmeans(vec.getWordMap(), Integer.parseInt(args[2]),Integer.parseInt(args[3]));
Classes[] explain = wordKmeans.explain();
File fw = new File(args[1]);
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(fw), "UTF-8"));
//explain.length is the classes number
for (int i = 0; i < explain.length; i++) {
List> result=explain[i].getMember();
StringBuffer buf = new StringBuffer();
for (int j = 0; j < result.size(); j++) {
buf.append(i+"\t"+result.get(j).getKey()+"\t"+result.get(j).getValue().toString()+"\n");
}
bw.write(buf.toString());
bw.flush();
}
bw.close();
}
public WordKmeans(HashMap wordMap, int clcn, int iter) {
this.wordMap = wordMap;
this.iter = iter;
cArray = new Classes[clcn];
}
public Classes[] explain() {
Iterator> iterator = wordMap.entrySet().iterator();
for (int i = 0; i < cArray.length; i++) {
Entry next = iterator.next();
cArray[i] = new Classes(i, next.getValue());
}
for (int i = 0; i < iter; i++) {
for (Classes classes : cArray) {
classes.clean();
}
iterator = wordMap.entrySet().iterator();
int cnt = 0;
while (iterator.hasNext()) {
if(cnt % 10000 ==0)
{
System.out.println("Iter:"+i+"\tword:"+(cnt));
}
cnt++;
Entry next = iterator.next();
double miniScore = Double.MAX_VALUE;
double tempScore;
int classesId = 0;
for (Classes classes : cArray) {
tempScore = classes.distance(next.getValue());
if (miniScore > tempScore) {
miniScore = tempScore;
classesId = classes.id;
}
}
cArray[classesId].putValue(next.getKey(), miniScore);
}
for (Classes classes : cArray) {
classes.updateCenter(wordMap);
}
System.out.println("iter " + i + " ok!");
}
return cArray;
}
public static class Classes {
private int id;
private float[] center;
public Classes(int id, float[] center) {
this.id = id;
this.center = center.clone();
}
Map values = new HashMap<>();
public double distance(float[] value) {
double sum = 0;
for (int i = 0; i < value.length; i++) {
sum += (center[i] - value[i])*(center[i] - value[i]) ;
}
return sum ;
}
public void putValue(String word, double score) {
values.put(word, score);
}
public void updateCenter(HashMap wordMap) {
for (int i = 0; i < center.length; i++) {
center[i] = 0;
}
float[] value = null;
for (String keyWord : values.keySet()) {
value = wordMap.get(keyWord);
for (int i = 0; i < value.length; i++) {
center[i] += value[i];
}
}
for (int i = 0; i < center.length; i++) {
center[i] = center[i] / values.size();
}
}
public void clean() {
values.clear();
}
public List> getTop(int n) {
List> arrayList = new ArrayList>(
values.entrySet());
Collections.sort(arrayList, new Comparator>() {
@Override
public int compare(Entry o1, Entry o2) {
return o1.getValue() > o2.getValue() ? 1 : -1;
}
});
int min = Math.min(n, arrayList.size() - 1);
if(min<=1){
return Collections.emptyList() ;
}
return arrayList.subList(0, min);
}
public List> getMember() {
List> arrayList = new ArrayList>(
values.entrySet());
Collections.sort(arrayList, new Comparator>() {
@Override
public int compare(Entry o1, Entry o2) {
return o1.getValue() > o2.getValue() ? 1 : -1;
}
});
int count=arrayList.size() - 1;
if(count<=1){
return Collections.emptyList() ;
}
return arrayList.subList(0, count);
}
}
}