文档过滤

算法来自于《集体智慧编程》第六章。

原书代码用 Python 实现；这两天在读这一章，于是改用 Java 实现。

 

package ch6DocumentFiltering;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Set;

/**
 * Trainable document classifier (port of the chapter 6 classifier from
 * "Programming Collective Intelligence").
 *
 * Category names are mapped to dense integer ids (see {@code catMap}) so that the
 * per-feature counts can be stored in an array indexed by category id.
 */
public class Classifier {
	/** feature -> counts per category id; a missing or null slot means a count of 0. */
	private HashMap<String, Integer[]> fc = new HashMap<String, Integer[]>();
	/** category name -> number of items trained into that category. */
	private HashMap<String, Integer> cc = new HashMap<String, Integer>();
	/** category name -> dense integer id, assigned in first-seen order by train(). */
	private HashMap<String, Integer> catMap = new HashMap<String, Integer>();
	/** category name -> threshold used by classify(); categories without an entry use 1.0. */
	private HashMap<String, Double> threshold = new HashMap<String, Double>();

	/** Creates an untrained classifier. */
	public Classifier() {
	}

	/**
	 * Returns the classification threshold of a category.
	 *
	 * @param key category name
	 * @return the threshold stored by setThreshold, or 1.0 if none was stored
	 */
	public double getThreshold(String key) {
		Double t = this.threshold.get(key);
		return t == null ? 1.0 : t;
	}

	/**
	 * Sets the classification threshold of a category. classify() only returns a
	 * category if its probability exceeds every other category's probability
	 * multiplied by this threshold.
	 *
	 * @param key category name
	 * @param t   threshold multiplier
	 */
	public void setThreshold(String key, double t) {
		this.threshold.put(key, t);
	}

	/**
	 * Extracts the features of a document (its distinct words, via DocClass).
	 *
	 * @param content document text
	 * @return the document's features
	 */
	public String[] getFeatures(String content) {
		DocClass doc = new DocClass();
		return doc.getWords(content);
	}

	/**
	 * Increments the count of a feature within one category.
	 *
	 * @param key feature
	 * @param cat category id as assigned in catMap; must be non-negative
	 */
	public void infc(String key, int cat) {
		Integer[] counts = this.fc.get(key);
		if (counts == null) {
			counts = new Integer[cat + 1];
		} else if (counts.length <= cat) {
			// Grow the array; new slots stay null (count 0). Only the target slot is
			// incremented below, so other categories are not inflated.
			Integer[] grown = new Integer[cat + 1];
			System.arraycopy(counts, 0, grown, 0, counts.length);
			counts = grown;
		}
		// The slot is null if the feature has not been seen in this category yet.
		counts[cat] = (counts[cat] == null ? 0 : counts[cat]) + 1;
		this.fc.put(key, counts);
	}

	/**
	 * Increments the number of items trained into a category.
	 *
	 * @param key category name
	 */
	public void incc(String key) {
		Integer n = this.cc.get(key);
		this.cc.put(key, n == null ? 1 : n + 1);
	}

	/**
	 * Returns how many times a feature has appeared in a category.
	 *
	 * @param key feature
	 * @param cat category id as assigned in catMap
	 * @return the count, or 0.0 if the feature was never seen in that category
	 */
	public double fcount(String key, int cat) {
		Integer[] counts = this.fc.get(key);
		if (counts == null || cat < 0 || cat >= counts.length || counts[cat] == null) {
			return 0.0;
		}
		return counts[cat];
	}

	/**
	 * Returns the number of items trained into a category.
	 *
	 * @param cat category name
	 * @return the item count, or 0 for an unknown category
	 */
	public int catCount(String cat) {
		Integer n = this.cc.get(cat);
		return n == null ? 0 : n;
	}

	/**
	 * Returns the total number of trained items across all categories.
	 *
	 * @return the sum of all per-category item counts
	 */
	public int totalCount() {
		int count = 0;
		for (Integer n : this.cc.values()) {
			count += n;
		}
		return count;
	}

	/**
	 * Returns the names of all trained categories. The result is the key-set view
	 * of the internal count map; callers should not modify it.
	 *
	 * @return category names
	 */
	public Set<String> getCategories() {
		return this.cc.keySet();
	}

	/**
	 * Trains the classifier with one document: every feature of the item is
	 * counted in the given category, and the category's item count is incremented.
	 *
	 * @param item document text
	 * @param cat  category name; a new category receives the next free id
	 */
	public void train(String item, String cat) {
		String[] features = getFeatures(item);

		Integer catId = this.catMap.get(cat);
		if (catId == null) {
			catId = this.catMap.size();
			this.catMap.put(cat, catId);
		}

		for (String f : features) {
			this.infc(f, catId);
		}

		this.incc(cat);
	}

	/**
	 * Computes the probability that a feature appears in a document of a category,
	 * i.e. fcount(feature, cat) / catCount(cat).
	 *
	 * @param key feature
	 * @param cat category name
	 * @return the probability, or 0.0 if the category has no items
	 */
	public double fprob(String key, String cat) {
		if (this.catCount(cat) == 0) {
			return 0.0;
		}
		return this.fcount(key, this.catMap.get(cat)) / (double) this.catCount(cat);
	}

	/**
	 * Returns fprob blended with an assumed prior: (weight * ap + totals * fprob) /
	 * (weight + totals), where totals is how often the feature has been seen across
	 * all categories. Rarely seen features therefore stay close to ap.
	 *
	 * @param key    feature
	 * @param cat    category name
	 * @param weight weight of the assumed probability, in "number of observations"
	 * @param ap     assumed probability
	 * @return the weighted probability
	 */
	public double weightedProb(String key, String cat, double weight, double ap) {
		double basicProb = fprob(key, cat);

		// Number of times this feature has been seen, across all categories.
		double totals = 0.0;
		for (String c : this.getCategories()) {
			totals += this.fcount(key, this.catMap.get(c));
		}

		return ((weight * ap) + (totals * basicProb)) / (weight + totals);
	}

	/**
	 * Finds the most probable category of an item using naive Bayes.
	 *
	 * @param item       document text
	 * @param defaultCat value returned when no category scores above zero, or when
	 *                   the best category does not win by its threshold factor
	 * @return the best category name, or defaultCat
	 */
	public String classify(String item, String defaultCat) {
		HashMap<String, Double> probs = new HashMap<String, Double>();
		Naivebayes n = new Naivebayes();
		String best = null;
		double max = 0.0;
		for (String cat : this.getCategories()) {
			double p = n.prob(this, item, cat);
			probs.put(cat, p);
			if (p > max) {
				max = p;
				best = cat;
			}
		}

		// No category scored above zero: there is no winner to validate.
		if (best == null) {
			return defaultCat;
		}

		// The winner must beat every other category by the winner's threshold factor.
		double bestThreshold = this.getThreshold(best);
		for (String cat : probs.keySet()) {
			if (!cat.equals(best) && probs.get(cat) * bestThreshold > max) {
				return defaultCat;
			}
		}
		return best;
	}

	/**
	 * Trains this classifier with a small fixed sample of "good" and "bad" documents.
	 */
	public void sampleTrain() {
		this.train("the quick brown fox jumps over the lazy dog", "good");
		this.train("make quick monkey in the online casino", "bad");
		this.train("Nobody owns the water.", "good");
		this.train("the quick rabbit jumps fences", "good");
		this.train("buy pharmaceuticals now", "bad");
	}

	public static void main(String[] args) {
		Classifier c1 = new Classifier();
		c1.sampleTrain();

		Fisher fisher = new Fisher();
		System.out.println(fisher.fisherProb(c1, "quick rabbit", "bad"));
	}
}

package ch6DocumentFiltering;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Locale;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Splits documents into features: the distinct lower-cased words they contain.
 */
public class DocClass {
	/**
	 * Runs of one or more word characters ([a-zA-Z_0-9]). Compiled once and shared;
	 * Pattern instances are immutable.
	 */
	private static final Pattern WORD = Pattern.compile("\\w+");

	/**
	 * Returns the distinct words of a document, lower-cased, keeping only words
	 * longer than 2 and shorter than 20 characters.
	 *
	 * @param content document text; null is treated as an empty document
	 * @return distinct words in first-occurrence order
	 */
	public String[] getWords(String content) {
		if (content == null) {
			return new String[0];
		}
		Set<String> words = new LinkedHashSet<String>();
		Matcher m = WORD.matcher(content);
		while (m.find()) {
			// Locale.ROOT: avoid locale-specific case mapping (e.g. Turkish dotless i).
			String word = m.group().toLowerCase(Locale.ROOT);
			if (word.length() > 2 && word.length() < 20) {
				words.add(word);
			}
		}
		return words.toArray(new String[0]);
	}

	public static void main(String[] args) {
		String sample = "A wiki ( /ˈwɪki/ WIK-ee) is a website that allows the easy[1] creation and editing of any number of interlinked web pages via a web browser using a simplified markup language or a WYSIWYG text editor.";
		String[] result = new DocClass().getWords(sample);
		for (String word : result) {
			System.out.println(word);
		}
	}
}

 

package ch6DocumentFiltering;

/**
 * Naive Bayes scoring on top of the counts held by a {@link Classifier}.
 */
public class Naivebayes {
	/**
	 * Computes Pr(Document | Category) as the product of the weighted probabilities
	 * (weight 1.0, assumed probability 0.5) of every feature of the document.
	 *
	 * @param c    trained classifier
	 * @param item document text
	 * @param cat  category name
	 * @return Pr(Document | Category)
	 */
	public double docProb(Classifier c, String item, String cat) {
		String[] features = c.getFeatures(item);
		double p = 1.0;
		for (String f : features) {
			p *= c.weightedProb(f, cat, 1.0, 0.5);
		}
		return p;
	}

	/**
	 * Computes Pr(Document | Category) * Pr(Category), which equals
	 * Pr(Category | Document) * Pr(Document). Since Pr(Document) is the same for
	 * every category, the result can be compared across categories.
	 *
	 * @param c    trained classifier
	 * @param item document text
	 * @param cat  category name
	 * @return the unnormalized score, or 0.0 if the classifier has no items or the
	 *         category has none
	 */
	public double prob(Classifier c, String item, String cat) {
		int total = c.totalCount();
		int catItems = c.catCount(cat);
		// Untrained classifier (0/0 would be NaN) or unknown category: score 0.
		if (total == 0 || catItems == 0) {
			return 0.0;
		}
		double catProb = (double) catItems / total;
		return catProb * docProb(c, item, cat);
	}

	public static void main(String[] args) {
		Classifier c = new Classifier();
		c.sampleTrain();
		Naivebayes n = new Naivebayes();
		System.out.println(n.prob(c, "quick monkey", "good"));
		System.out.println(n.prob(c, "rabbit", "bad"));
	}
}

 

package ch6DocumentFiltering;

import java.util.HashMap;
import java.util.Iterator;

/**
 * Fisher-method classifier (chapter 6 of "Programming Collective Intelligence"),
 * operating on the counts held by a {@link Classifier}.
 */
public class Fisher {
	/** category -> minimum Fisher probability required by classify(); absent means 0.0. */
	private HashMap<String, Double> minimum = new HashMap<String, Double>();

	/**
	 * Returns the minimum probability a category must exceed in classify().
	 *
	 * @param cat category name
	 * @return the stored minimum, or 0.0 if none was stored
	 */
	public double getMinimum(String cat) {
		Double min = this.minimum.get(cat);
		return min == null ? 0.0 : min;
	}

	/**
	 * Sets the minimum probability a category must exceed in classify().
	 *
	 * @param cat category name
	 * @param min minimum probability
	 */
	public void setMinimum(String cat, double min) {
		this.minimum.put(cat, min);
	}

	/**
	 * Computes the probability that an item containing feature f belongs to cat:
	 * fprob(f, cat) divided by the sum of fprob(f, c) over all categories.
	 *
	 * @param c   trained classifier
	 * @param f   feature
	 * @param cat category name
	 * @return the probability, or 0 if f was never seen in cat
	 */
	public double cprob(Classifier c, String f, String cat) {
		// Probability of the feature within this category.
		double clf = c.fprob(f, cat);
		if (clf == 0) {
			return 0;
		}

		// Sum of the feature's probabilities over all categories (>= clf > 0).
		double freqSum = 0;
		for (String catTemp : c.getCategories()) {
			freqSum += c.fprob(f, catTemp);
		}

		return clf / freqSum;
	}

	/**
	 * cprob blended with an assumed prior, analogous to Classifier.weightedProb but
	 * using cprob as the basic probability, as the Fisher method requires.
	 */
	private double weightedCprob(Classifier c, String f, String cat, double weight, double ap) {
		double basicProb = cprob(c, f, cat);

		// Total occurrences of f across all categories. Classifier.fcount is keyed by
		// a private category id, so recover each count as fprob(f, cat) * catCount(cat);
		// rounding removes floating-point error from the divide/multiply round trip.
		double totals = 0;
		for (String catTemp : c.getCategories()) {
			totals += Math.round(c.fprob(f, catTemp) * c.catCount(catTemp));
		}

		return ((weight * ap) + (totals * basicProb)) / (weight + totals);
	}

	/**
	 * Computes the Fisher probability that an item belongs to a category: the
	 * weighted cprob values of all features are combined as
	 * -2 * sum(log p) and fed to the inverse chi-square function with
	 * 2 * (number of features) degrees of freedom.
	 *
	 * @param c    trained classifier
	 * @param item document text
	 * @param cat  category name
	 * @return probability in [0, 1]
	 */
	public double fisherProb(Classifier c, String item, String cat) {
		String[] features = c.getFeatures(item);
		// Summing logs is equivalent to log(product) but cannot underflow to log(0).
		double logSum = 0.0;
		for (String f : features) {
			logSum += Math.log(weightedCprob(c, f, cat, 1.0, 0.5));
		}
		double fScore = -2 * logSum;

		return invchi2(fScore, 2 * features.length);
	}

	/**
	 * Inverse chi-square function: the probability that a chi-square variable with
	 * df degrees of freedom is at least chi, evaluated with the closed-form series
	 * sum_{i=0}^{df/2-1} e^{-m} m^i / i!, where m = chi / 2 (exact for even df).
	 * See <a href = "http://baike.baidu.com/view/859454.htm">chi-square distribution</a>.
	 *
	 * @param chi chi-square statistic
	 * @param df  degrees of freedom (expected to be even)
	 * @return the probability, capped at 1.0
	 */
	public double invchi2(double chi, double df) {
		double m = chi / 2.0;
		double sum, term;
		sum = term = Math.exp(-m);
		int terms = (int) (df / 2);
		for (int i = 1; i < terms; i++) {
			term *= m / i;
			sum += term;
		}

		// Guard against rounding pushing the sum slightly above 1.
		return Math.min(sum, 1.0);
	}

	/**
	 * Returns the category with the highest Fisher probability, considering only
	 * categories whose probability exceeds their configured minimum.
	 *
	 * @param c          trained classifier
	 * @param item       document text
	 * @param defaultCat value returned when no category qualifies
	 * @return the best category name, or defaultCat
	 */
	public String classify(Classifier c, String item, String defaultCat) {
		String best = defaultCat;
		double max = 0.0;
		for (String catTemp : c.getCategories()) {
			double p = this.fisherProb(c, item, catTemp);
			if (p > this.getMinimum(catTemp) && p > max) {
				best = catTemp;
				max = p;
			}
		}
		return best;
	}

	public static void main(String[] args) {
		Classifier c = new Classifier();
		Fisher fisher = new Fisher();
		fisher.setMinimum("bad", 0.8);
		fisher.setMinimum("good", 0.4);
		c.sampleTrain();
		System.out.println(fisher.classify(c, "casino", "none"));
	}
}
 

你可能感兴趣的:(C++,c,python,C#,F#)