From Raw Documents to a KNN Classifier Implementation (Part 1)

  1. Build the raw corpus.
    Create three files named 0, 1, 2 under ~/corpus/C1, with the following contents:

    0: 眼睛 明亮 健康 身体 发达

    1: 高大 身材 胳膊 勇猛 四肢

    2: 胳膊 健康 身体 高大 健康

    Create three files named 3, 4, 5 under ~/corpus/C2:

    3: 美国 很 发达

    4: 美国 经济 富强

    5: 日本 富强 科技 发达
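
    If you would rather generate this toy corpus programmatically than by hand, here is a minimal sketch (the BuildCorpus class name and the use of user.home are my own additions; the original simply created the files manually):

    import java.io.File;
    import java.io.FileWriter;
    import java.io.IOException;

    // Hypothetical helper, not part of the original pipeline:
    // writes the six toy documents into ~/corpus/C1 and ~/corpus/C2.
    public class BuildCorpus {
        public static void main(String[] args) throws IOException {
            String home = System.getProperty("user.home");
            String[][] classes = {
                {"眼睛 明亮 健康 身体 发达", "高大 身材 胳膊 勇猛 四肢", "胳膊 健康 身体 高大 健康"}, // C1: docs 0-2
                {"美国 很 发达", "美国 经济 富强", "日本 富强 科技 发达"}                             // C2: docs 3-5
            };
            int docId = 0;
            for (int c = 0; c < classes.length; c++) {
                File dir = new File(home + "/corpus/C" + (c + 1));
                dir.mkdirs(); // create the class folder if it does not exist
                for (String content : classes[c]) {
                    FileWriter fw = new FileWriter(new File(dir, String.valueOf(docId++)));
                    fw.write(content);
                    fw.close();
                }
            }
        }
    }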

  2. Create the folder ~/frequency and run WordCount.java. Its input is ~/corpus; when it finishes, ~/frequency contains six new files, 0 through 5, each recording how many times every word occurs in the corresponding raw document.

/**
 * Author: Orisun
 * Date: Aug 29, 2011
 * FileName: WordCount.java
 * Function: calculate word frequency in each document.
 */
package featureselect;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map.Entry;

public class WordCount {

    // The argument is the folder holding the (preprocessed) raw training documents
    public void wordCount(File srcFile) {
        if (srcFile.isDirectory()) {
            File[] children = srcFile.listFiles();
            for (File child : children) {
                wordCount(child);
            }
        } else if (srcFile.isFile()) {
            HashMap<String, Integer> wordfreq = new HashMap<String, Integer>(); // occurrence count of each word
            try {
                FileReader fr = new FileReader(srcFile);
                BufferedReader br = new BufferedReader(fr);
                String line = null;
                while ((line = br.readLine()) != null) {
                    String[] words = line.split("\\s+");
                    for (String word : words) {
                        if (!wordfreq.containsKey(word)) {
                            wordfreq.put(word, 1);
                        } else {
                            int n = wordfreq.get(word);
                            wordfreq.put(word, n + 1);
                        }
                    }
                }
                br.close();

                // write the word frequencies to a new file of the same name under ~/frequency
                File newFile = new File("/home/orisun/frequency", srcFile.getName());
                newFile.createNewFile();
                FileWriter fw = new FileWriter(newFile);
                BufferedWriter bw = new BufferedWriter(fw);
                Iterator<Entry<String, Integer>> iter = wordfreq.entrySet().iterator();
                while (iter.hasNext()) {
                    Entry<String, Integer> entry = iter.next();
                    bw.write(entry.getKey());
                    bw.write("\t");
                    bw.write(String.valueOf(entry.getValue()));
                    bw.newLine();
                }
                bw.flush();
                bw.close();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    public static void main(String[] args) {
        WordCount inst = new WordCount();
        File file = new File("/home/orisun/corpus");
        if (!file.exists()) {
            System.out.println("Input does not exist. Exiting.");
            System.exit(2);
        }
        inst.wordCount(file);
    }
}
  3. The next step is to run WordDocMatrix.java. Its input is the six files under ~/frequency produced in step 2; its output is the file ~/matrix, which records how many times each word occurs in each document:

            0   1   2   3   4   5
    发达    1   0   0   1   0   1
    健康    1   0   2   0   0   0
    富强    0   0   0   0   1   1
    日本    0   0   0   0   0   1
    四肢    0   1   0   0   0   0
    眼睛    1   0   0   0   0   0
    高大    0   1   1   0   0   0
    很      0   0   0   1   0   0
    身体    1   0   1   0   0   0
    身材    0   1   0   0   0   0
    明亮    1   0   0   0   0   0
    胳膊    0   1   1   0   0   0
    经济    0   0   0   0   1   0
    勇猛    0   1   0   0   0   0
    科技    0   0   0   0   0   1
    美国    0   0   0   1   1   0
/**
 * Author: Orisun
 * Date: Aug 29, 2011
 * FileName: WordDocMatrix.java
 * Function: the word-document matrix records each word's frequency in each document.
 */
package featureselect;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map.Entry;

public class WordDocMatrix {

    // The toy training set has 2 classes with 3 documents each; adjust these constants for a real corpus
    public static final int docnumber = 6;
    public static final int classnumber = 2;
    // matrix records how many times each word occurs in each document
    HashMap<String, ArrayList<Short>> matrix = new HashMap<String, ArrayList<Short>>();

    // The argument is the folder whose files each store one document's word counts
    public void buildMatrix(File srcFile) {
        if (srcFile.isDirectory()) {
            File[] children = srcFile.listFiles();
            for (File child : children) {
                buildMatrix(child);
            }
        } else if (srcFile.isFile()) {
            int filename = Integer.parseInt(srcFile.getName()); // the file name doubles as the document id
            try {
                FileReader fr = new FileReader(srcFile);
                BufferedReader br = new BufferedReader(fr);
                String line = null;
                while ((line = br.readLine()) != null) {
                    String[] pair = line.split("\\s+"); // pair[0] is the word, pair[1] its count
                    if (!matrix.containsKey(pair[0])) {
                        ArrayList<Short> al = new ArrayList<Short>(docnumber);
                        for (int i = 0; i < docnumber; i++)
                            al.add((short) 0);
                        al.set(filename, Short.parseShort(pair[1]));
                        matrix.put(pair[0], al);
                    } else {
                        ArrayList<Short> al = matrix.get(pair[0]);
                        short orig = al.get(filename);
                        al.set(filename, (short) (orig + Short.parseShort(pair[1])));
                    }
                }
                br.close();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    // Print the matrix to a file for verification. (For a large corpus, printing all of it
    // would make the file too big to load quickly, perhaps even to open at all.)
    public static void main(String[] args) {
        WordDocMatrix inst = new WordDocMatrix();
        try {
            File Mfile = new File("/home/orisun/frequency");
            if (!Mfile.exists()) {
                System.out.println("Input does not exist. Exiting.");
                System.exit(2);
            }
            inst.buildMatrix(Mfile);
            File file = new File("/home/orisun/matrix");
            file.createNewFile();
            PrintStream ps = new PrintStream(file);
            inst.printMatrix(ps, inst.matrix);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    // Write matrix to the given stream, one word per row
    public void printMatrix(PrintStream out, HashMap<String, ArrayList<Short>> matrix) {
        Iterator<Entry<String, ArrayList<Short>>> iter = matrix.entrySet().iterator();
        try {
            while (iter.hasNext()) {
                Entry<String, ArrayList<Short>> entry = iter.next();
                out.print(entry.getKey());
                out.print("\t");
                for (int i = 0; i < docnumber; i++) {
                    out.print(String.valueOf(entry.getValue().get(i)));
                    out.print("\t");
                }
                out.println();
            }
            out.flush();
            out.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

4. IG.java does not need to be run on its own in the overall pipeline; it only provides functions for the feature-selection step that follows. When run, its nominal input is the six files under ~/frequency: it reads the word-doc matrix from the file ~/matrix via initMatrix, then computes and outputs each word's information gain (IG):

发达 0.0566330123
健康 0.3182570841
富强 0.3182570841
日本 0.1323041247
四肢 0.1323041247
眼睛 0.1323041247
高大 0.3182570841
很 0.1323041247
身体 0.3182570841
身材 0.1323041247
明亮 0.1323041247
胳膊 0.3182570841
经济 0.1323041247
勇猛 0.1323041247
科技 0.1323041247
美国 0.3182570841
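
For reference, what calIG computes (with natural logarithms, matching Math.log) is the standard information gain:

    IG(w) = H(C) - P(w) H(C|w) - P(\bar{w}) H(C|\bar{w})

As a worked check for 健康, which occurs in documents 0 and 2 (both in class C1): H(C) = ln 2 ≈ 0.6931, P(w) = 2/6, H(C|w) = 0 since every document containing the word is in C1, and H(C|\bar{w}) = -(1/4) ln(1/4) - (3/4) ln(3/4) ≈ 0.5623, giving IG ≈ 0.6931 - (4/6)·0.5623 ≈ 0.3183, which matches the 0.3182570841 above.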
/**
 * Author: Orisun
 * Date: Aug 29, 2011
 * FileName: IG.java
 * Function: calculate the information gain of each word in the documents.
 */
package featureselect;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map.Entry;

public class IG {

    // matrix records how many times each word occurs in each document
    public static HashMap<String, ArrayList<Short>> matrix = new HashMap<String, ArrayList<Short>>();
    HashMap<String, Double> map = new HashMap<String, Double>(); // information gain of each word

    // The argument file stores the word-doc matrix
    public static void initMatrix(File srcFile) {
        if (!(srcFile.exists() && srcFile.isFile())) {
            System.out.println("Matrix file does not exist. Exiting.");
            System.exit(2);
        }
        try {
            FileReader fr = new FileReader(srcFile);
            BufferedReader br = new BufferedReader(fr);
            String line = null;
            while ((line = br.readLine()) != null) {
                String[] content = line.split("\\s+");
                String word = content[0];
                ArrayList<Short> al = new ArrayList<Short>(WordDocMatrix.docnumber);
                for (int i = 0; i < WordDocMatrix.docnumber; i++) {
                    short count = Short.parseShort(content[i + 1]);
                    al.add(count);
                }
                matrix.put(word, al);
            }
            br.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    // Compute IG for every word. The srcFile argument is effectively unused:
    // the word-doc matrix is always (re)read from the hardcoded ~/matrix path.
    public void calIG(File srcFile) {
        initMatrix(new File("/home/orisun/matrix"));
        Iterator<Entry<String, ArrayList<Short>>> iter = matrix.entrySet().iterator();
        double entropy = Math.log(WordDocMatrix.classnumber); // class entropy H(C), assuming equal-sized classes
        while (iter.hasNext()) {
            Entry<String, ArrayList<Short>> entry = iter.next();
            String word = entry.getKey();
            ArrayList<Short> al = entry.getValue();
            int category = WordDocMatrix.docnumber / WordDocMatrix.classnumber; // documents per class
            int wcount = 0; // number of documents containing word
            int[] wcount_class = new int[WordDocMatrix.classnumber]; // per-class number of documents containing word
            double pw = 0.0; // fraction of all documents that contain word
            double[] pcw = new double[WordDocMatrix.classnumber]; // class distribution among documents containing word
            double[] pcw_b = new double[WordDocMatrix.classnumber]; // class distribution among documents not containing word
            for (int i = 0; i < WordDocMatrix.classnumber; i++) {
                for (int j = 0; j < category; j++) {
                    if (al.get(j + i * category) > 0) {
                        wcount_class[i]++;
                    }
                }
                wcount += wcount_class[i];
            }
            pw = 1.0 * wcount / WordDocMatrix.docnumber;
            for (int i = 0; i < WordDocMatrix.classnumber; i++) {
                pcw[i] = 1.0 * wcount_class[i] / wcount;
                pcw_b[i] = 1.0 * (category - wcount_class[i])
                        / (WordDocMatrix.docnumber - wcount);
            }
            double d1 = 0.0;
            double d2 = 0.0;
            for (int i = 0; i < WordDocMatrix.classnumber; i++) {
                d1 += pcw[i] * Math.log(pcw[i] + Double.MIN_VALUE); // Double.MIN_VALUE guards against log(0)
                d2 += pcw_b[i] * Math.log(pcw_b[i] + Double.MIN_VALUE);
            }
            double ig = entropy + pw * d1 + (1.0 - pw) * d2;
            map.put(word, ig);
        }
    }

    public static void main(String[] args) {
        IG inst = new IG();
        File matrixFile = new File("/home/orisun/matrix");
        initMatrix(matrixFile);
        File freqFile = new File("/home/orisun/frequency");
        inst.calIG(freqFile);
    }
}

5. FS.java does need to be run. Its input is the six files under ~/frequency. It calls the calIG method of IG.java, sorts the words by information gain in descending order, and writes the N terms with the largest IG values to the file ~/features (~/features is created when FS.java runs). Here we select 6 feature words; in descending IG order they are:

  1. 健康 富强 高大 身体 胳膊 美国

    This is also the order in which they are stored in FS.features (FS.features is an ArrayList).

/**
 * Author: Orisun
 * Date: Aug 29, 2011
 * FileName: FS.java
 * Function: feature selection with information gain.
 */
package featureselect;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;

public class FS {

    public static ArrayList<String> features = new ArrayList<String>(); // the finally selected feature words

    // Select the n words with the largest information gain as features
    public void featureSelect(int n) {
        IG inst = new IG();
        inst.calIG(new File("/home/orisun/frequency"));
        ArrayList<Entry<String, Double>> list = sort(inst.map);
        Iterator<Entry<String, Double>> iter = list.iterator();
        int index = 0;
        while (index++ < n && iter.hasNext()) {
            Entry<String, Double> entry = iter.next();
            features.add(entry.getKey());
        }
    }

    // Sort a map's entries by value, in descending order
    public ArrayList<Entry<String, Double>> sort(HashMap<String, Double> arg) {
        ArrayList<Entry<String, Double>> al = new ArrayList<Entry<String, Double>>(arg.entrySet());
        Collections.sort(al, new Comparator<Map.Entry<String, Double>>() {
            public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) {
                return o2.getValue().compareTo(o1.getValue());
            }
        });
        return al;
    }

    // Write the finally selected feature words to a file
    public static void main(String[] args) {
        FS inst = new FS();
        inst.featureSelect(6);
        try {
            File file = new File("/home/orisun/features");
            file.createNewFile();
            FileWriter fw = new FileWriter(file);
            BufferedWriter bw = new BufferedWriter(fw);
            Iterator<String> iter = FS.features.iterator();
            while (iter.hasNext()) {
                String feature = iter.next();
                bw.write(feature);
                bw.newLine();
            }
            bw.flush();
            bw.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
    
6. Build the document vector model by running TVSM.java. It first reads the feature terms from the file ~/features into TVSM.features (TVSM.features is a HashSet), at which point their order in TVSM.features is:

  1. 健康 富强 胳膊 高大 美国 身体

    It then calls IG.initMatrix to read the word-doc matrix from file, computes each feature term's weight in each document, and so obtains the document vector model.
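
    The weight computed in buildDVM is a tf-idf style score; the code's expression -tf * Math.log(Ni/N + Double.MIN_VALUE) is simply

        w(t, d) = tf(t, d) \cdot \ln(N / N_t)

    where tf(t, d) is term t's frequency in document d, N = 6 is the number of training documents, and N_t is the number of documents containing t. For example, 健康 occurs once in document 0 and appears in 2 of the 6 documents, so its weight there is 1 · ln(6/2) ≈ 1.0986; in document 2 it occurs twice, giving 2 · ln 3 ≈ 2.1972.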
    The six document vectors before normalization (one row per document, features in the order listed above) are:

    doc 0:  1.0986122886  0.0           0.0           0.0           0.0           1.0986122886
    doc 1:  0.0           0.0           1.0986122886  1.0986122886  0.0           0.0
    doc 2:  2.1972245773  0.0           1.0986122886  1.0986122886  0.0           1.0986122886
    doc 3:  0.0           0.0           0.0           0.0           1.0986122886  0.0
    doc 4:  0.0           1.0986122886  0.0           0.0           1.0986122886  0.0
    doc 5:  0.0           1.0986122886  0.0           0.0           0.0           0.0

    The six document vectors after normalization are:

    doc 0:  0.4472135954  0.0            0.0            0.0            0.0            0.70710678118
    doc 1:  0.0           0.0            0.70710678118  0.70710678118  0.0            0.0
    doc 2:  0.8944271909  0.0            0.70710678118  0.70710678118  0.0            0.70710678118
    doc 3:  0.0           0.0            0.0            0.0            0.70710678118  0.0
    doc 4:  0.0           0.70710678118  0.0            0.0            0.70710678118  0.0
    doc 5:  0.0           0.70710678118  0.0            0.0            0.0            0.0
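
    Note that unionVector normalizes per dimension rather than per document: the sqrt[] array accumulates each dimension's sum of squared weights over all documents, and every component is divided by the square root of that sum,

        v'_j(d) = v_j(d) / \sqrt{\sum_{d'} v_j(d')^2}

    For the 健康 dimension, for example, the sum is 1.0986² + 2.1972² ≈ 6.0346, so document 0's entry becomes 1.0986 / 2.4566 ≈ 0.4472 and document 2's becomes 2.1972 / 2.4566 ≈ 0.8944, matching the values above.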

/**
 * Author: Orisun
 * Date: Aug 29, 2011
 * FileName: TVSM.java
 * Function: build the document vector model for the training documents.
 */
package DVM;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map.Entry;
import java.util.Vector;

public class TVSM {

    HashSet<String> features = new HashSet<String>(); // the selected feature words
    int len; // number of features, i.e. the length of a document vector
    File path = new File("/home/orisun/dvm"); // folder where the document vectors are written
    HashMap<String, Vector<Double>> dv = new HashMap<String, Vector<Double>>(); // key: document name, value: its un-normalized vector
    double[] sqrt; // per-dimension sum of squared weights, accumulated for normalization

    public void printFeature() {
        Iterator<String> iter = features.iterator();
        while (iter.hasNext()) {
            System.out.print(iter.next() + "\t");
        }
        System.out.println();
    }

    // Read the feature terms from file. The argument file stores the terms kept after feature selection.
    public void initFeatures(File file) {
        try {
            FileReader fr = new FileReader(file);
            BufferedReader br = new BufferedReader(fr);
            String line = null;
            while ((line = br.readLine()) != null) {
                features.add(line.split("\\s+")[0]);
            }
            br.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
        len = features.size();
        sqrt = new double[len];
    }

    // The argument is the ~/frequency folder; each file in it stores one document's word counts
    public void buildDVM(File srcFile) {
        if (srcFile.isDirectory()) {
            File[] children = srcFile.listFiles();
            for (File child : children) {
                buildDVM(child);
            }
        } else if (srcFile.isFile()) {
            featureselect.IG.initMatrix(new File("/home/orisun/matrix")); // (re)load the word-doc matrix
            Object[] feature_array = features.toArray();
            HashMap<String, Double> fea_wei = new HashMap<String, Double>();

            int filename = Integer.parseInt(srcFile.getName());
            Vector<Double> v = new Vector<Double>(len);
            for (int i = 0; i < len; i++)
                v.add(0.0);
            try {
                FileReader fr = new FileReader(srcFile);
                BufferedReader br = new BufferedReader(fr);
                String line = null;
                while ((line = br.readLine()) != null) { // one "word<TAB>count" pair per line
                    String word = line.split("\\s+")[0];
                    if (!features.contains(word))
                        continue;
                    int tf = 0; // the feature's frequency in this document
                    int Ni = 0; // number of documents containing the feature
                    int N = featureselect.WordDocMatrix.docnumber; // total number of documents
                    ArrayList<Short> al = featureselect.IG.matrix.get(word);
                    tf = al.get(filename);
                    for (int i = 0; i < N; i++) {
                        if (al.get(i) > 0)
                            Ni++;
                    }
                    double weight = -1.0 * tf
                            * Math.log(1.0 * Ni / N + Double.MIN_VALUE); // i.e. tf * ln(N/Ni)
                    fea_wei.put(word, weight);
                }
                br.close();
                for (int i = 0; i < feature_array.length; i++) {
                    String feat = feature_array[i].toString();
                    double w = 0.0;
                    if (fea_wei.containsKey(feat))
                        w = fea_wei.get(feat);
                    v.set(i, w);
                    sqrt[i] += Math.pow(w, 2);
                }
                dv.put(String.valueOf(filename), v);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    // Normalize each dimension of the document vectors and write them to files under ~/dvm
    public void unionVector() {
        Iterator<Entry<String, Vector<Double>>> iter = dv.entrySet().iterator();
        while (iter.hasNext()) {
            Entry<String, Vector<Double>> entry = iter.next();
            String fname = entry.getKey();
            Vector<Double> v = entry.getValue();
            Iterator<Double> it = v.iterator();

            File newFile = new File(path, fname);
            try {
                newFile.createNewFile();
                FileWriter fw = new FileWriter(newFile);
                BufferedWriter bw = new BufferedWriter(fw);

                int index = 0;
                while (it.hasNext()) {
                    double d = it.next();
                    d /= Math.sqrt(sqrt[index]); // divide by the root of this dimension's sum of squares
                    bw.write(String.valueOf(d)); // write the normalized component
                    bw.newLine();
                    index++;
                }
                bw.flush();
                bw.close();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    public static void main(String[] args) {
        TVSM inst = new TVSM();
        File feaFile = new File("/home/orisun/features");
        inst.initFeatures(feaFile);
        inst.path.mkdirs(); // make sure the output folder ~/dvm exists
        File freqFile = new File("/home/orisun/frequency"); // the word-count files from step 2

        if (!freqFile.exists()) {
            System.out.println("Input does not exist. Exiting.");
            System.exit(2);
        }
        inst.buildDVM(freqFile);
        inst.unionVector();
    }
}

7. Create a new file, ~/unknown, with the content:

  1. 美国 比 日本 富强

    Next we need to decide whether this document belongs to class C1 or class C2.

8. Using the WordCount.java program (with its input and output paths changed accordingly), we count how many times each word occurs in the unknown file and write the counts back into the file itself.
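
A minimal sketch of that modified run (the CountUnknown class name is my own; it reuses WordCount's counting logic but reads ~/unknown and overwrites it in place):

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map.Entry;

// Hypothetical one-off variant of WordCount for step 8: counts the words in
// ~/unknown and overwrites the file with "word<TAB>count" lines.
public class CountUnknown {
    public static void main(String[] args) throws IOException {
        File file = new File("/home/orisun/unknown");
        HashMap<String, Integer> wordfreq = new HashMap<String, Integer>();
        BufferedReader br = new BufferedReader(new FileReader(file));
        String line;
        while ((line = br.readLine()) != null) {
            for (String word : line.split("\\s+")) {
                Integer n = wordfreq.get(word);
                wordfreq.put(word, n == null ? 1 : n + 1);
            }
        }
        br.close();
        BufferedWriter bw = new BufferedWriter(new FileWriter(file)); // overwrite in place
        for (Entry<String, Integer> e : wordfreq.entrySet()) {
            bw.write(e.getKey() + "\t" + e.getValue());
            bw.newLine();
        }
        bw.close();
    }
}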

9. Run AVSM.java to build the document vector model for the unknown file; the vector is written to the file ~/dvm2/unknown. Its document vector (no normalization is needed here) is:

    0.0  1.0986122886681098  0.0  0.0  1.0986122886681098  0.0

/**
 * Author: Orisun
 * Date: Aug 31, 2011
 * FileName: AVSM.java
 * Function: build the document vector model for a newly arrived document.
 */
package DVM;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Vector;

public class AVSM {

    HashSet<String> features = new HashSet<String>(); // the selected feature words
    int len; // number of features, i.e. the length of a document vector
    File path = new File("/home/orisun/dvm2"); // folder where the document vector is written

    // Read the feature terms from file. The argument file stores the terms kept after feature selection.
    public void initFeatures(File file) {
        try {
            FileReader fr = new FileReader(file);
            BufferedReader br = new BufferedReader(fr);
            String line = null;
            while ((line = br.readLine()) != null) {
                features.add(line.split("\\s+")[0]);
            }
            br.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
        len = features.size();
    }

    // The argument file stores the document's word counts ("word<TAB>count" per line)
    public void buildDVM(File srcFile) {
        if (srcFile.isDirectory()) {
            File[] children = srcFile.listFiles();
            for (File child : children) {
                buildDVM(child);
            }
        } else if (srcFile.isFile()) {
            featureselect.IG.initMatrix(new File("/home/orisun/matrix"));
            Object[] feature_array = features.toArray();
            HashMap<String, Double> fea_wei = new HashMap<String, Double>();

            String filename = srcFile.getName();
            File newFile = new File(path, filename);
            Vector<Double> v = new Vector<Double>(len);
            for (int i = 0; i < len; i++)
                v.add(0.0);
            try {
                newFile.createNewFile();
                FileReader fr = new FileReader(srcFile);
                BufferedReader br = new BufferedReader(fr);
                String line = null;
                while ((line = br.readLine()) != null) {
                    String[] content = line.split("\\s+");
                    String word = content[0];
                    int tf = Integer.parseInt(content[1]); // the word's frequency in this document
                    if (!features.contains(word))
                        continue;
                    int Ni = 0; // number of training documents containing the feature
                    int N = featureselect.WordDocMatrix.docnumber; // total number of training documents
                    ArrayList<Short> al = featureselect.IG.matrix.get(word);
                    for (int i = 0; i < N; i++) {
                        if (al.get(i) > 0)
                            Ni++;
                    }
                    double weight = -1.0 * tf
                            * Math.log(1.0 * Ni / N + Double.MIN_VALUE); // i.e. tf * ln(N/Ni)
                    fea_wei.put(word, weight);
                }
                br.close();
                for (int i = 0; i < feature_array.length; i++) {
                    String feat = feature_array[i].toString();
                    double w = 0.0;
                    if (fea_wei.containsKey(feat))
                        w = fea_wei.get(feat);
                    v.set(i, w);
                }

                // write the document vector to a file under the dvm2 folder
                FileWriter fw = new FileWriter(newFile);
                BufferedWriter bw = new BufferedWriter(fw);
                Iterator<Double> iter = v.iterator();
                while (iter.hasNext()) {
                    bw.write(String.valueOf(iter.next()));
                    bw.newLine();
                }
                bw.flush();
                bw.close();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    public static void main(String[] args) {
        AVSM inst = new AVSM();
        File feaFile = new File("/home/orisun/features");
        inst.initFeatures(feaFile);
        inst.path.mkdirs(); // make sure ~/dvm2 exists
        File freqFile = new File("/home/orisun/unknown");
        if (!freqFile.exists()) {
            System.out.println("Input does not exist. Exiting.");
            System.exit(2);
        }
        inst.buildDVM(freqFile);
    }
}

10. Run KNN.java. The cosines of the angles between the unknown vector and the six training vectors come out as:

    doc 0:  0
    doc 1:  0
    doc 2:  0
    doc 3:  0.7071
    doc 4:  1
    doc 5:  0.7071

A larger cosine means a smaller angle, which means more similar vectors. The unknown file is clearly most similar to documents 3, 4, and 5, so it should belong to class 2!
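
For completeness, the similarity computed by the cos() method below is the ordinary cosine of the angle between two vectors:

    cos(u, v) = (\sum_i u_i v_i) / (||u|| \cdot ||v||)

As a spot check with training document 4 = (0, 0.7071, 0, 0, 0.7071, 0) and the unknown vector (0, 1.0986, 0, 0, 1.0986, 0): the dot product is 2 · 0.7071 · 1.0986 ≈ 1.5536, while ||doc 4|| = 1 and ||unknown|| = 1.0986 · √2 ≈ 1.5536, so the cosine is exactly 1, as reported for document 4.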

/**
 * Author: Orisun
 * Date: Aug 30, 2011
 * FileName: KNN.java
 * Function: K-Nearest Neighbour Algorithm.
 */
package classification;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Vector;
import java.util.Map.Entry;

import featureselect.FS;

public class KNN {

    HashMap<String, Double> dist = new HashMap<String, Double>(); // cosine similarity between the document to classify and each training document
    Vector<Double> uv = new Vector<Double>(); // vector representation of the document to classify

    // srcFile stores the vector representation of the document to classify
    public void initUV(File srcFile) {
        if (!srcFile.exists()) {
            System.err.println("File not found. Program exits with failure.");
            System.exit(2);
        }
        try {
            FileReader fr = new FileReader(srcFile);
            BufferedReader br = new BufferedReader(fr);
            String line = null;
            while ((line = br.readLine()) != null) {
                uv.add(Double.valueOf(line.trim()));
            }
            br.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    // srcFile is the folder holding the training documents' vectors
    public void calDist(File srcFile) {
        File[] children = srcFile.listFiles();

        for (File child : children) {
            String filename = child.getName();
            Vector<Double> v = new Vector<Double>();
            try {
                FileReader fr = new FileReader(child);
                BufferedReader br = new BufferedReader(fr);
                String line = null;
                while ((line = br.readLine()) != null) {
                    v.add(Double.valueOf(line.trim()));
                }
                br.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
            int len = v.size();
            double d = cos(v, uv, len);
            dist.put(filename, d);
            System.out.println("Cosine similarity to document " + filename + " is " + d);
        }
    }

    // Cosine of the angle between two vectors. The larger its absolute value,
    // the smaller the angle and the more similar (closer) the vectors.
    public double cos(Vector<Double> v1, Vector<Double> v2, int len) {
        double res = 0.0;
        double mul = 0.0;
        double p1 = 0.0, p2 = 0.0;
        for (int i = 0; i < len; i++) {
            double one = v1.get(i);
            double two = v2.get(i);
            mul += one * two;
            p1 += Math.pow(one, 2);
            p2 += Math.pow(two, 2);
        }
        res = Math.abs(mul) / Math.sqrt(p1 * p2);
        return res;
    }

    public void knn(int k) {
        // sort the similarities to all training vectors in descending order
        FS fs = new FS();
        ArrayList<Entry<String, Double>> dist_list = fs.sort(dist);
        int c1 = 0, c2 = 0;
        for (int i = 0; i < k; i++) { // vote among the k nearest neighbours
            Entry<String, Double> entry = dist_list.get(i);
            int fileNum = Integer.parseInt(entry.getKey());
            if (fileNum >= 0 && fileNum < 3)
                c1++;
            else if (fileNum >= 3 && fileNum < 6)
                c2++;
        }
        if (c1 > c2)
            System.out.println("It belongs to class 1!");
        else if (c2 > c1)
            System.out.println("It belongs to class 2!");
        else
            System.out.println("The two classes are equally likely!");
    }

    public static void main(String[] args) {
        KNN inst = new KNN();
        File uvFile = new File("/home/orisun/dvm2/unknown");
        inst.initUV(uvFile);
        File tvFiles = new File("/home/orisun/dvm");
        inst.calDist(tvFiles);
        inst.knn(3);
    }
}

The above is only a small-scale simulation. In real text classification, when the training set is large, the word-doc matrix becomes huge and will inevitably overflow memory. For a solution, see my "From Raw Documents to a KNN Classifier Implementation (Part 2)".
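
One obvious direction, sketched here purely as my own illustration and not as the actual solution from Part 2: since most entries of the word-doc matrix are zero, each word's row can be stored sparsely, keeping counts only for the documents where the word occurs.

import java.util.HashMap;

// Illustrative sparse row of the word-doc matrix: only documents in which the
// word actually occurs get an entry, instead of one slot per document.
public class SparseRow {
    private final HashMap<Integer, Short> counts = new HashMap<Integer, Short>();

    public void add(int docId, short count) {
        Short old = counts.get(docId);
        counts.put(docId, (short) ((old == null ? 0 : old) + count));
    }

    public short get(int docId) {
        Short c = counts.get(docId);
        return c == null ? 0 : c; // an absent entry means the word never occurs in docId
    }

    public int documentFrequency() { // Ni: the number of documents containing the word
        return counts.size();
    }
}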
