Mahout垃圾邮件过滤模型demo

一、准备工作

  1. 安装Apache Mahout

    Apache Mahout的安装非常简单。只要下载压缩包解压就可以了。本文使用的是Apache Mahout 0.10.0,hadoop版本2.5.2

  2. 获取训练数据

    地址:http://spamassassin.apache.org/publiccorpus/

    这个是Apache Spam Assassin项目下的数据

二、训练模型

  1.     将数据放到hdfs上,如/tmp/train。目录里分别包含垃圾邮件目录spam和非垃圾邮件目录ham。

  2.    下面是模型训练脚本。脚本参考了mahout-distribution-0.10.0/examples/bin/classify-20newsgroups.sh。

#!/bin/sh

mahout seqdirectory -i /tmp/train -o /tmp/train-seq -ow && \

hadoop fs  -rmr /tmp/train-vectors && \
mahout seq2sparse -i /tmp/train-seq -o /tmp/train-vectors -lnorm -nv  -wt tfidf

if [ $? -ne 0 ]; then
exit 1
fi

echo "=============================="
echo "seq2sparse finished"
echo "=============================="

mahout split -i /tmp/train-vectors/tfidf-vectors \
--trainingOutput /tmp/antispam-train-vectors \
--testOutput /tmp/antispam-test-vectors \
--randomSelectionPct 40 --overwrite --sequenceFiles -xm sequential && \

mahout trainnb -i /tmp/antispam-train-vectors -o /tmp/model -li /tmp/labelindex -ow -c
echo "=============================="
echo "trainnb finished"
echo "=============================="

三、搭建邮件过滤web server

1.    构造classifier

package com.melot.antispam;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.vectorizer.TFIDF;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;

/*
 * 		public class Classifier
 * 
 * 		@desc Classifies tweets using a model trained in mahout.
 * 
 * 		@author Vicente Ruben Del Pino Ruiz <<[email protected]>>
 * 
 */
public class Classifier {

	private AbstractVectorClassifier classifier;
	private NaiveBayesModel naiveBayesModel;
	private Map<String, Integer> dictionary;
	private Map<Integer, Long> documentFrequency;
	private int documentCount;
	private Map<Integer, String> labelIndex;

	public final static String DICTIONARY_PATH_CONF = "dictionaryPath";
	public final static String DOCUMENT_FREQUENCY_PATH_CONF = "documentFrequencyPath";
	public final static String DOCUMENT_PATH_CONF = "documentPath";
	public final static String MODEL_PATH_CONF = "modelPath";
	public final static String LABEL_INDEX = "labelIndex";

	/*
	 * public Classifier
	 * 
	 * @desc Initialize all the variables to be used and reads: --Dictionnary or
	 * train set. --Document Frequency of train set. --Naive Bayes model. --
	 * 
	 * @param Configuration configuration. Configuration of the cluster.
	 */
	public Classifier(Configuration configuration) throws IOException {

		String modelPath = configuration.getStrings(MODEL_PATH_CONF)[0];
		String dictionaryPath = configuration.getStrings(DICTIONARY_PATH_CONF)[0];
		String documentFrequencyPath = configuration
				.getStrings(DOCUMENT_FREQUENCY_PATH_CONF)[0];
		String documentPath = configuration.getStrings(DOCUMENT_PATH_CONF)[0];
		String labelIndexPath = configuration.getStrings(LABEL_INDEX)[0];

		dictionary = readDictionnary(configuration, new Path(dictionaryPath));
		documentFrequency = readDocumentFrequency(configuration, new Path(
				documentFrequencyPath));
		documentCount = readDocumentCount(configuration, new Path(documentPath));
		labelIndex = readLabelIndex(configuration, new Path(labelIndexPath));

		naiveBayesModel = NaiveBayesModel.materialize(new Path(modelPath),
				configuration);
		classifier = new StandardNaiveBayesClassifier(naiveBayesModel);

	}

	/*
	 * private int getBestCategory
	 * 
	 * @desc Gets the best category for a vector of scores
	 * 
	 * @param Vector result. Vector with scores
	 * 
	 * @return bestCategoryID. Best category for the vector
	 */

	private int getBestCategory(Vector result) {

		// Iterate through the scores and take the category with the higher.
		double bestScore = -Double.MAX_VALUE;
		int bestCategoryId = -1;

		for (Element element : result.all()) {
			int categoryId = element.index();
			double score = element.get();

			if (score > bestScore) {
				bestScore = score;
				bestCategoryId = categoryId;
			}
		}

		return bestCategoryId;
	}

	/*
	 * private Vector generateTFIDFVector
	 * 
	 * @desc Generates a TFIDF vector for words
	 * 
	 * @param HasMap words. Words to use for generating the vector
	 * 
	 * @param int wordCount. Number of words
	 * 
	 * @param int documentCount. Document counting
	 * 
	 * @return bestCategoryID. Best category for the vector
	 */

	private Vector generateTFIDFVector(HashMap<String, Integer> words,
			int wordCount, int documentCount) {

		Vector vector = new RandomAccessSparseVector(10000);
		TFIDF tfidf = new TFIDF();

		// Create a TF-IDF vector for each tweet

		for (Entry<String, Integer> entry : words.entrySet()) {

			String word = entry.getKey();
			int count = entry.getValue();
			Integer wordId = dictionary.get(word);
			Long freq = documentFrequency.get(wordId);
			double tfIdfValue = tfidf.calculate(count, freq.intValue(),
					wordCount, documentCount);
			vector.setQuick(wordId, tfIdfValue);

		}

		return vector;
	}

	/*
	 * public int classify
	 * 
	 * @desc Main part of the class. Classifies a text using Naive Bayes and a
	 * model created by Mahout.
	 * 
	 * @param String text. Text of the tweet to classify
	 */
	public String classify(String text) throws IOException {

		HashMap<String, Integer> words = new HashMap<String, Integer>();
		int bestCategoryID;

		// Create our own TF-IDF vector with the tweet text
		String delims = "[ ]+";
		String[] ts = text.split(delims);

		int wordCount = 0;

		// Iterate through each word in the tweet and calculate its counting.
		for (int i = 0; i < ts.length; i++) {

			String word = ts[i];
			Integer wordId = dictionary.get(word);
			// Only take words that are in our train set
			if (wordId != null) {
				if (!words.containsKey(word)) {
					words.put(word, 1);
				} else {
					int countWord = words.get(word) + 1;
					words.put(word, countWord);
				}
				wordCount++;
			}
		}

		// Generate TFIDF vector
		Vector vector = generateTFIDFVector(words, wordCount, documentCount);

		// Classify the TF-IDF vector created using Mahout model
		Vector result = classifier.classifyFull(vector);

		// Get the best category for the vector
		bestCategoryID = getBestCategory(result);

		return labelIndex.get(bestCategoryID);

	}

	/*
	 * private static readDictionary
	 * 
	 * @desc Reads the dictionnary and loads in memory. With bigs train sets
	 * this consumes a high volume of memory.
	 * 
	 * @param Configuration Conf. Configuration from the cluster
	 * 
	 * @param Path dictionnaryPath. Path to the dictionnary file.
	 */

	private Map<String, Integer> readDictionnary(Configuration conf,
			Path dictionnaryPath) {
		Map<String, Integer> dictionnary = new HashMap<String, Integer>();

		for (Pair<Text, IntWritable> pair : new SequenceFileIterable<Text, IntWritable>(
				dictionnaryPath, true, conf)) {
			dictionnary.put(pair.getFirst().toString(), pair.getSecond().get());
		}
		return dictionnary;

	}

	/*
	 * private static readDocumentFrequency
	 * 
	 * @desc Reads the document frequency and loads in memory. With bigs train
	 * sets this consumes a high volume of memory.
	 * 
	 * @param Configuration Conf. Configuration from the cluster
	 * 
	 * @param Path documentFrequencyPath. Path to the document frequency file.
	 */
	private Map<Integer, Long> readDocumentFrequency(Configuration conf,
			Path documentFrequencyPath) {

		Map<Integer, Long> documentFrequency = new HashMap<Integer, Long>();

		for (Pair<IntWritable, LongWritable> pair : new SequenceFileIterable<IntWritable, LongWritable>(
				documentFrequencyPath, true, conf)) {
			documentFrequency
					.put(pair.getFirst().get(), pair.getSecond().get());

		}

		return documentFrequency;
	}

	private Map<Integer, String> readLabelIndex(Configuration conf,
			Path labelIndexPath) {
		Map<Integer, String> labelIndex = new HashMap<Integer, String>();

		for (Pair<Text, IntWritable> pair : new SequenceFileIterable<Text, IntWritable>(
				labelIndexPath, true, conf)) {
			labelIndex.put(pair.getSecond().get(), pair.getFirst().toString());
		}

		return labelIndex;
	}

	private int readDocumentCount(Configuration conf, Path documentPath)
			throws IOException {
		FileSystem fs = FileSystem.get(conf);
		FileStatus[] files = fs.listStatus(documentPath);
		int documentCount = 0;
		for (FileStatus file : files) {
			if (file.isDirectory()) {
				documentCount += readDocumentCount(conf, file.getPath());
			} else {
				documentCount += 1;
			}
		}

		return documentCount;
	}
}

2.    创建一个简单的HttpServlet

package com.melot.antispam;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.hadoop.conf.Configuration;

@SuppressWarnings("serial")
public class AntiSpamServlet extends HttpServlet {
	private Classifier classifier;

	@Override
	public void init() {
		Configuration conf = new Configuration();
		conf.set("fs.defaultFS", Config.hdfsURI);
		conf.set(Classifier.DICTIONARY_PATH_CONF, Config.dictionaryPath);
		conf.set(Classifier.DOCUMENT_FREQUENCY_PATH_CONF,
				Config.documentFrequencyPath);
		conf.set(Classifier.DOCUMENT_PATH_CONF, Config.documentPath);
		conf.set(Classifier.MODEL_PATH_CONF, Config.modelPath);
		conf.set(Classifier.LABEL_INDEX, Config.labelIndexPath);

		try {
			classifier = new Classifier(conf);
		} catch (IOException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
	}

	@Override
	protected void doPost(HttpServletRequest req, HttpServletResponse resp)
			throws ServletException, IOException {

		String email = readAll(req.getReader());
		long t0 = System.currentTimeMillis();
		String category = classifier.classify(email);
		long t1 = System.currentTimeMillis();

		resp.getWriter().print(
				String.format("{\"category\":\"%s\", \"time\": %d}", category,
						t1 - t0));

	}

	@Override
	protected void doGet(HttpServletRequest req, HttpServletResponse resp)
			throws ServletException, IOException {

		String email = readAll(req.getReader());
		long t0 = System.currentTimeMillis();
		String category = classifier.classify(email);
		long t1 = System.currentTimeMillis();

		resp.getWriter().print(
				String.format("{\"category\":\"%s\", \"time\": %d}", category,
						t1 - t0));

	}

	private String readAll(Reader reader) throws IOException {

		BufferedReader bufReader = new BufferedReader(reader);

		String line;
		StringBuilder sb = new StringBuilder();
		while ((line = bufReader.readLine()) != null) {
			sb.append(line).append(" ");
		}

		return sb.toString();
	}
}

详细代码可以参考我在OSChina上的git:https://git.oschina.net/moriarty279/antispam.git

四、参考资料

  1.     针对旧版本mahout的垃圾邮件过滤模型例子——https://emmaespina.wordpress.com/2011/04/26/ham-spam-and-elephants-or-how-to-build-a-spam-filter-server-with-mahout/

  2.     主要代码来自https://github.com/MovieTrender/TwitterClassifier.git,但是它获取documentCount的方式有问题

你可能感兴趣的:(Mahout垃圾邮件过滤模型demo)