贝叶斯文本分类 java实现

  昨天实现了一个基于贝叶斯定理的的文本分类,贝叶斯定理假设特征属性(在文本中就是词汇)对待分类项的影响都是独立的,道理比较简单,在中文分类系统中,分类的准确性与分词系统的好坏有很大的关系,这段代码也是试验不同分词系统才顺手写的一个。 

  试验数据用的sogou实验室的文本分类样本,一共分为9个类别,每个类别文件夹下大约有2000篇文章。由于文本数据量确实较大,所以得想办法让每次训练的结果都能保存起来,以便于下次直接使用,我这里使用序列化的方式保存在硬盘。 


  训练代码如下:  

 

  1 /**
  2  * 训练器
  3  * 
  4  * <a href="http://my.oschina.net/arthor" target="_blank" rel="nofollow">@author</a>  duyf
  5  * 
  6  */
  7 class Train implements Serializable {
  8 
  9     /**
 10      * 
 11      */
 12     private static final long serialVersionUID = 1L;
 13 
 14     public final static String SERIALIZABLE_PATH = "D:\\workspace\\Test\\SogouC.mini\\Sample\\Train.ser";
 15     // 训练集的位置
 16     private String trainPath = "D:\\workspace\\Test\\SogouC.mini\\Sample";
 17 
 18     // 类别序号对应的实际名称
 19     private Map<String, String> classMap = new HashMap<String, String>();
 20 
 21     // 类别对应的txt文本数
 22     private Map<String, Integer> classP = new ConcurrentHashMap<String, Integer>();
 23 
 24     // 所有文本数
 25     private AtomicInteger actCount = new AtomicInteger(0);
 26 
 27     
 28 
 29     // 每个类别对应的词典和频数
 30     private Map<String, Map<String, Double>> classWordMap = new ConcurrentHashMap<String, Map<String, Double>>();
 31 
 32     // 分词器
 33     private transient Participle participle;
 34 
 35     private static Train trainInstance = new Train();
 36 
 37     public static Train getInstance() {
 38         trainInstance = new Train();
 39 
 40         // 读取序列化在硬盘的本类对象
 41         FileInputStream fis;
 42         try {
 43             File f = new File(SERIALIZABLE_PATH);
 44             if (f.length() != 0) {
 45                 fis = new FileInputStream(SERIALIZABLE_PATH);
 46                 ObjectInputStream oos = new ObjectInputStream(fis);
 47                 trainInstance = (Train) oos.readObject();
 48                 trainInstance.participle = new IkParticiple();
 49             } else {
 50                 trainInstance = new Train();
 51             }
 52         } catch (Exception e) {
 53             e.printStackTrace();
 54         }
 55 
 56         return trainInstance;
 57     }
 58 
 59     private Train() {
 60         this.participle = new IkParticiple();
 61     }
 62 
 63     public String readtxt(String path) {
 64         BufferedReader br = null;
 65         StringBuilder str = null;
 66         try {
 67             br = new BufferedReader(new FileReader(path));
 68 
 69             str = new StringBuilder();
 70 
 71             String r = br.readLine();
 72 
 73             while (r != null) {
 74                 str.append(r);
 75                 r = br.readLine();
 76 
 77             }
 78 
 79             return str.toString();
 80         } catch (IOException ex) {
 81             ex.printStackTrace();
 82         } finally {
 83             if (br != null) {
 84                 try {
 85                     br.close();
 86                 } catch (IOException e) {
 87                     e.printStackTrace();
 88                 }
 89             }
 90             str = null;
 91             br = null;
 92         }
 93 
 94         return "";
 95     }
 96 
 97     /**
 98      * 训练数据
 99      */
100     public void realTrain() {
101         // 初始化
102         classMap = new HashMap<String, String>();
103         classP = new HashMap<String, Integer>();
104         actCount.set(0);
105         classWordMap = new HashMap<String, Map<String, Double>>();
106 
107         // classMap.put("C000007", "汽车");
108         classMap.put("C000008", "财经");
109         classMap.put("C000010", "IT");
110         classMap.put("C000013", "健康");
111         classMap.put("C000014", "体育");
112         classMap.put("C000016", "旅游");
113         classMap.put("C000020", "教育");
114         classMap.put("C000022", "招聘");
115         classMap.put("C000023", "文化");
116         classMap.put("C000024", "军事");
117 
118         // 计算各个类别的样本数
119         Set<String> keySet = classMap.keySet();
120 
121         // 所有词汇的集合,是为了计算每个单词在多少篇文章中出现,用于后面计算df
122         final Set<String> allWords = new HashSet<String>();
123 
124         // 存放每个类别的文件词汇内容
125         final Map<String, List<String[]>> classContentMap = new ConcurrentHashMap<String, List<String[]>>();
126 
127         for (String classKey : keySet) {
128 
129             Participle participle = new IkParticiple();
130             Map<String, Double> wordMap = new HashMap<String, Double>();
131             File f = new File(trainPath + File.separator + classKey);
132             File[] files = f.listFiles(new FileFilter() {
133 
134                 @Override
135                 public boolean accept(File pathname) {
136                     if (pathname.getName().endsWith(".txt")) {
137                         return true;
138                     }
139                     return false;
140                 }
141 
142             });
143 
144             // 存储每个类别的文件词汇向量
145             List<String[]> fileContent = new ArrayList<String[]>();
146             if (files != null) {
147                 for (File txt : files) {
148                     String content = readtxt(txt.getAbsolutePath());
149                     // 分词
150                     String[] word_arr = participle.participle(content, false);
151                     fileContent.add(word_arr);
152                     // 统计每个词出现的个数
153                     for (String word : word_arr) {
154                         if (wordMap.containsKey(word)) {
155                             Double wordCount = wordMap.get(word);
156                             wordMap.put(word, wordCount + 1);
157                         } else {
158                             wordMap.put(word, 1.0);
159                         }
160                         
161                     }
162                 }
163             }
164 
165             // 每个类别对应的词典和频数
166             classWordMap.put(classKey, wordMap);
167 
168             // 每个类别的文章数目
169             classP.put(classKey, files.length);
170             actCount.addAndGet(files.length);
171             classContentMap.put(classKey, fileContent);
172 
173         }
174 
175         
176 
177         
178 
179         // 把训练好的训练器对象序列化到本地 (空间换时间)
180         FileOutputStream fos;
181         try {
182             fos = new FileOutputStream(SERIALIZABLE_PATH);
183             ObjectOutputStream oos = new ObjectOutputStream(fos);
184             oos.writeObject(this);
185         } catch (Exception e) {
186             e.printStackTrace();
187         }
188 
189     }
190 
191     /**
192      * 分类
193      * 
194      * @param text
195      * <a href="http://my.oschina.net/u/556800" target="_blank" rel="nofollow">@return</a>  返回各个类别的概率大小
196      */
197     public Map<String, Double> classify(String text) {
198         // 分词,并且去重
199         String[] text_words = participle.participle(text, false);
200 
201         Map<String, Double> frequencyOfType = new HashMap<String, Double>();
202         Set<String> keySet = classMap.keySet();
203         for (String classKey : keySet) {
204             double typeOfThis = 1.0;
205             Map<String, Double> wordMap = classWordMap.get(classKey);
206             for (String word : text_words) {
207                 Double wordCount = wordMap.get(word);
208                 int articleCount = classP.get(classKey);
209 
210                 /*
211                  * Double wordidf = idfMap.get(word); if(wordidf==null){
212                  * wordidf=0.001; }else{ wordidf = Math.log(actCount / wordidf); }
213                  */
214 
215                 // 假如这个词在类别下的所有文章中木有,那么给定个极小的值 不影响计算
216                 double term_frequency = (wordCount == null) ? ((double) 1 / (articleCount + 1))
217                         : (wordCount / articleCount);
218 
219                 // 文本在类别的概率 在这里按照特征向量独立统计,即概率=词汇1/文章数 * 词汇2/文章数 。。。
220                 // 当double无限小的时候会归为0,为了避免 *10
221 
222                 typeOfThis = typeOfThis * term_frequency * 10;
223                 typeOfThis = ((typeOfThis == 0.0) ? Double.MIN_VALUE
224                         : typeOfThis);
225                 // System.out.println(typeOfThis+" : "+term_frequency+" :
226                 // "+actCount);
227             }
228 
229             typeOfThis = ((typeOfThis == 1.0) ? 0.0 : typeOfThis);
230 
231             // 此类别文章出现的概率
232             double classOfAll = classP.get(classKey) / actCount.doubleValue();
233 
234             // 根据贝叶斯公式 $(A|B)=S(B|A)*S(A)/S(B),由于$(B)是常数,在这里不做计算,不影响分类结果
235             frequencyOfType.put(classKey, typeOfThis * classOfAll);
236         }
237 
238         return frequencyOfType;
239     }
240 
241     public void pringAll() {
242         Set<Entry<String, Map<String, Double>>> classWordEntry = classWordMap
243                 .entrySet();
244         for (Entry<String, Map<String, Double>> ent : classWordEntry) {
245             System.out.println("类别: " + ent.getKey());
246             Map<String, Double> wordMap = ent.getValue();
247             Set<Entry<String, Double>> wordMapSet = wordMap.entrySet();
248             for (Entry<String, Double> wordEnt : wordMapSet) {
249                 System.out.println(wordEnt.getKey() + ":" + wordEnt.getValue());
250             }
251         }
252     }
253 
254     public Map<String, String> getClassMap() {
255         return classMap;
256     }
257 
258     public void setClassMap(Map<String, String> classMap) {
259         this.classMap = classMap;
260     }
261 
262 }

 

 

 

 

  在试验过程中,发觉某篇文章的分类不太准,某篇IT文章分到招聘类别下了,在仔细对比了训练数据后,发觉这是由于招聘类别每篇文章下面都带有“搜狗”的标志,而待分类的这篇IT文章里面充斥这搜狗这类词汇,结果招聘类下的概率比较大。由此想到,在除了做常规的贝叶斯计算时,需要把不同文本中出现次数多的词汇权重降低甚至删除(好比关键词搜索中的tf-idf),通俗点讲就是,在所有训练文本中某词汇(如的,地,得)出现的次数越多,这个词越不重要,比如IT文章中“软件”和“应用”这两个词汇,“应用”应该是很多文章类别下都有的,反而不太重要,但是“软件”这个词汇大多只出现在IT文章里,出现在大量文章的概率并不大。 我这里原本打算计算每个词的idf,然后给定一个阀值来判断是否需要纳入计算,但是由于词汇太多,计算量较大(等待结果时间较长),所以暂时注释掉了。 

 

 

来源:http://my.oschina.net/duyunfei/blog/80283

你可能感兴趣的:(java实现)