算法来自于《集体智慧编程》-第六章
原书代码用 Python 实现,这两天看这章书,改用 Java 实现。
package ch6DocumentFiltering; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.Set; public class Classifier { private HashMap<String, Integer[]> fc = new HashMap<String, Integer[]>(); private HashMap<String, Integer> cc = new HashMap<String, Integer>(); private HashMap<String, Integer> catMap = new HashMap<String, Integer>(); private HashMap<String, Double> threshold = new HashMap<String, Double>(); // private String[] features; public Classifier() { // fc = null; // cc = null; // this.features = getFeatures(content); } public double getThreshold(String key) { if(this.threshold.get(key) == null) return 1.0; return this.threshold.get(key); } public void setThreshold(String key, double t) { this.threshold.put(key, t); } /** * * @param content * @return */ public String[] getFeatures(String content) { DocClass doc = new DocClass(); return doc.getWords(content); } /** * 增加某一分类的计数值 * * @param self * @param key * @param cat * @return */ public void infc(String key, int cat) { if (this.fc.get(key) != null) { Integer[] temp = this.fc.get(key); if (temp.length > cat) { Integer[] result = new Integer[temp.length]; for (int i = 0; i < temp.length; i++) { if (i == cat) result[i] = temp[i] + 1; else result[i] = temp[i]; } this.fc.put(key, result); } else { Integer[] result = new Integer[cat + 1]; for (int i = 0; i < temp.length; i++) { result[i] = temp[i]; } for (int j = temp.length; j < cat + 1; j++) { result[j] = 1; } this.fc.put(key, result); } } else { Integer[] result = new Integer[cat + 1]; result[cat] = 1; this.fc.put(key, result); } } public void incc(String key) { if (this.cc.get(key) != null) this.cc.put(key, this.cc.get(key) + 1); else this.cc.put(key, 1); } /** * 某一特征出现在某分类中的次数 * * @param key * @param cat * @return */ public double fcount(String key, int cat) { if (this.fc.get(key) != null) { Integer[] temp = this.fc.get(key); if (temp.length > cat) { if (this.fc.get(key)[cat] != null) return (double) 
this.fc.get(key)[cat]; } } return 0.0; } /** * 某一分类的内容项数量 * * @param cat * @return */ public int catCount(String cat) { if (this.cc.get(cat) != null) return this.cc.get(cat); return 0; } /** * 所有内容项数量 * * @return */ public int totalCount() { int count = 0; for(Iterator<String> i = this.cc.keySet().iterator(); i.hasNext();){ String key = i.next(); count += this.cc.get(key); } return count; } /** * 分类列表 * * @return */ public Set<String> getCategories() { return this.cc.keySet(); } /** * * @param item * @param cat */ public void train(String item, String cat) { String[] features = getFeatures(item); int intCat = -1; if (this.catMap.get(cat) == null) { this.catMap.put(cat, new Integer(this.catMap.size())); } intCat = this.catMap.get(cat); for (String f : features) { this.infc(f, intCat); } this.incc(cat); } /** * 计算单词咋分类中出现的概率 * * @param key * @param cat * @return */ public double fprob(String key, String cat) { if (this.catCount(cat) == 0) return 0.0; return this.fcount(key, this.catMap.get(cat)) / (double)this.catCount(cat); } /** * * @param key * @param cat * @param weight * @param ap * @return */ public double weightedProb(String key, String cat, double weight, double ap) { double basicProb = fprob(key, cat); int totals = 0; String[] cats = this.getCategories().toArray(new String[0]); for (String c : cats) { totals = (int) (totals + this.fcount(key, this.catMap.get(c))); } double bp = ((weight * ap) + (totals * basicProb)) / (weight + totals); return bp; } /** * 找出最可能的分类 * * @param item * @param defaultCat * @return */ public String classify(String item, String defaultCat){ String best = defaultCat; double max = 0.0; HashMap<String, Double> probs = new HashMap<String, Double>(); Naivebayes n = new Naivebayes(); for(Iterator<String> i = this.getCategories().iterator(); i.hasNext();){ String cat = i.next(); probs.put(cat, n.prob(this, item, cat)); if(probs.get(cat) > max){ max = probs.get(cat); best = cat; } } for(Iterator<String> i = probs.keySet().iterator(); 
i.hasNext(); ){ String cat = i.next(); if(cat == best) continue; if(probs.get(cat)*this.getThreshold(best) > probs.get(best)) return defaultCat; } return best; } /** * * @param c1 */ public void sampleTrain() { this.train("the quick brown fox jumps over the lazy dog", "good"); this.train("make quick monkey in the online casino", "bad"); this.train("Nobody owns the water.", "good"); this.train("the quick rabbit jumps fences", "good"); this.train("buy pharmaceuticals now", "bad"); } public static void main(String[] args) { Classifier c1 = new Classifier(); c1.sampleTrain(); // c1.setThreshold("bad", 3); // System.out.println(c1.classify("quick monkey", "unknow")); Fisher fisher = new Fisher(); // System.out.println(fisher.cprob(c1, "money", "bad")); System.out.println(fisher.fisherProb(c1, "quick rabbit", "bad")); } }
package ch6DocumentFiltering;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Locale;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * Splits a text into its distinct lowercase words.
 */
public class DocClass {

    /** A word is a non-empty run of {@code \w} characters. */
    private static final Pattern WORD = Pattern.compile("\\w+");

    /** Words must be strictly longer than this many characters. */
    private static final int MIN_LENGTH_EXCLUSIVE = 2;

    /** Words must be strictly shorter than this many characters. */
    private static final int MAX_LENGTH_EXCLUSIVE = 20;

    /**
     * Returns the distinct words of a text.
     *
     * <p>Words are lowercased independently of the default locale, and only
     * words of 3 to 19 characters are kept. The order of the returned array
     * is unspecified.
     *
     * @param content text to split; {@code null} is treated as empty
     * @return the distinct words, possibly an empty array
     */
    public String[] getWords(String content) {
        Set<String> words = new HashSet<String>();
        if (content == null) {
            return new String[0];
        }
        Matcher m = WORD.matcher(content);
        while (m.find()) {
            // Locale.ROOT keeps "I" -> "i" even under e.g. a Turkish default locale.
            String word = m.group().toLowerCase(Locale.ROOT);
            if (word.length() > MIN_LENGTH_EXCLUSIVE && word.length() < MAX_LENGTH_EXCLUSIVE) {
                words.add(word);
            }
        }
        return words.toArray(new String[0]);
    }

    public static void main(String[] args) {
        String sample = "A wiki ( /ˈwɪki/ WIK-ee) is a website that allows the easy[1] creation and editing of any number of interlinked web pages via a web browser using a simplified markup language or a WYSIWYG text editor.";
        String[] result = new DocClass().getWords(sample);
        for (String word : result) {
            System.out.println(word);
        }
    }
}
package ch6DocumentFiltering; public class Naivebayes { /** * pr(Document|Category) * * @param c * @param item * @param cat * @return */ public double docProb(Classifier c, String item, String cat) { String[] features = c.getFeatures(item); double p = 1.0; for (String f : features) { p *= c.weightedProb(f, cat, 1.0, 0.5); } return p; } /** * pr(Category|Document)*pr(Document) * @param c * @param item * @param cat * @return */ public double prob(Classifier c, String item, String cat){ double catProb = ((double)c.catCount(cat)/c.totalCount()); //System.out.println(c.totalCount()); double douDocProb = docProb(c, item, cat); return catProb*douDocProb; } public static void main(String[] args){ Classifier c = new Classifier(); c.sampleTrain(); Naivebayes n = new Naivebayes(); System.out.println(n.prob(c, "quick monkey", "good")); System.out.println(n.prob(c, "rabbit", "bad")); } }
package ch6DocumentFiltering;

import java.util.HashMap;
import java.util.Iterator;

/**
 * Fisher-method scoring and classification on top of a {@code Classifier}.
 */
public class Fisher {

    /** category name -> minimum score a document needs to be put in it. */
    private HashMap<String, Double> minimum = new HashMap<String, Double>();

    /**
     * Returns the minimum score registered for a category.
     *
     * @param cat category name
     * @return the registered minimum, or 0.0 when none was set
     */
    public double getMinimum(String cat) {
        Double value = this.minimum.get(cat);
        return value == null ? 0.0 : value.doubleValue();
    }

    /**
     * Registers the minimum score for a category.
     *
     * @param cat category name
     * @param min minimum score
     */
    public void setMinimum(String cat, double min) {
        this.minimum.put(cat, Double.valueOf(min));
    }

    /**
     * Share of a feature's per-category frequency that falls into one
     * category: {@code fprob(f, cat)} divided by the sum of {@code fprob}
     * over all categories.
     *
     * @param c   trained classifier
     * @param f   feature
     * @param cat category name
     * @return the share, or 0 when the feature never occurs in {@code cat}
     */
    public double cprob(Classifier c, String f, String cat) {
        // Frequency of the feature within this category.
        double clf = c.fprob(f, cat);
        if (clf == 0) {
            return 0;
        }
        // Sum of the feature's frequencies over all categories.
        double freqSum = 0;
        for (String other : c.getCategories()) {
            freqSum += c.fprob(f, other);
        }
        return clf / freqSum;
    }

    /**
     * Combines the per-feature probabilities of a document with Fisher's
     * method: the weighted {@link #cprob} values are multiplied, and
     * {@code -2 * ln(product)} is fed to {@link #invchi2} with
     * {@code 2 * featureCount} degrees of freedom.
     *
     * @param c    trained classifier
     * @param item document text
     * @param cat  category name
     * @return the combined score
     */
    public double fisherProb(Classifier c, String item, String cat) {
        double p = 1.0;
        String[] features = c.getFeatures(item);
        for (String f : features) {
            p *= weightedCprob(c, f, cat, 1.0, 0.5);
        }
        double fScore = -2 * Math.log(p);
        return invchi2(fScore, 2 * features.length);
    }

    /**
     * Inverse chi-square function: for an even number of degrees of freedom
     * the upper-tail probability is
     * {@code exp(-chi/2) * sum_{i=0}^{df/2-1} (chi/2)^i / i!}.
     *
     * @param chi chi-square value
     * @param df  degrees of freedom (expected to be even)
     * @return the tail probability, capped at 1.0
     */
    public double invchi2(double chi, double df) {
        double m = chi / 2.0;
        if (m == Double.POSITIVE_INFINITY) {
            // exp(-inf) * inf would turn into NaN; the tail probability is 0.
            return 0.0;
        }
        double term = Math.exp(-m);
        double sum = term;
        int half = (int) (df / 2);
        for (int i = 1; i < half; i++) {
            term *= m / i;
            sum += term;
        }
        return Math.min(sum, 1.0);
    }

    /**
     * Returns the category with the highest {@link #fisherProb} among those
     * whose score exceeds their registered minimum. The winning score is
     * printed to standard output.
     *
     * @param c          trained classifier
     * @param item       document text
     * @param defaultCat value returned when no category qualifies
     * @return the winning category name or {@code defaultCat}
     */
    public String classify(Classifier c, String item, String defaultCat) {
        String best = defaultCat;
        double max = 0.0;
        for (String candidate : c.getCategories()) {
            double p = this.fisherProb(c, item, candidate);
            if (p > this.getMinimum(candidate) && p > max) {
                best = candidate;
                max = p;
            }
        }
        System.out.println(max);
        return best;
    }

    /**
     * Weighted average of {@link #cprob} and an assumed probability, using the
     * same formula as {@code Classifier.weightedProb}.
     */
    private double weightedCprob(Classifier c, String f, String cat, double weight, double ap) {
        double basicProb = cprob(c, f, cat);
        // Classifier.train numbers categories 0..n-1 in order of first
        // appearance, so summing fcount over those indices gives the total
        // count of the feature over all categories.
        int categoryCount = c.getCategories().size();
        int totals = 0;
        for (int i = 0; i < categoryCount; i++) {
            totals += (int) c.fcount(f, i);
        }
        return ((weight * ap) + (totals * basicProb)) / (weight + totals);
    }

    public static void main(String[] args) {
        Classifier c = new Classifier();
        Fisher fisher = new Fisher();
        fisher.setMinimum("bad", 0.8);
        fisher.setMinimum("good", 0.4);
        c.sampleTrain();
        System.out.println(fisher.classify(c, "casino", "none"));
    }
}