I. Preparation
Installing Apache Mahout
Installing Apache Mahout is straightforward: download the release archive and unpack it. This article uses Apache Mahout 0.10.0 with Hadoop 2.5.2.
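A minimal sketch of the installation steps follows; the download URL, archive name, and local paths are assumptions and may need to be adjusted for your mirror and environment:

# Download and unpack the Mahout 0.10.0 binary distribution
# (archive name and mirror are assumptions; adjust as needed).
wget https://archive.apache.org/dist/mahout/0.10.0/mahout-distribution-0.10.0.tar.gz
tar xzf mahout-distribution-0.10.0.tar.gz

# Point the mahout launcher at itself and at the existing Hadoop 2.5.2
# installation so that jobs are submitted to the cluster instead of
# running locally.
export MAHOUT_HOME=$PWD/mahout-distribution-0.10.0
export PATH=$MAHOUT_HOME/bin:$PATH
export HADOOP_HOME=/usr/local/hadoop-2.5.2   # adjust to your Hadoop install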
Getting the training data
URL: http://spamassassin.apache.org/publiccorpus/
This is the public corpus published by the Apache SpamAssassin project.
II. Training the Model
Upload the data to HDFS, for example under /tmp/train. That directory contains a spam subdirectory with the spam messages and a ham subdirectory with the non-spam messages.
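For example, assuming the SpamAssassin archives have already been downloaded and unpacked into local ham/ and spam/ directories (the local directory names are just an illustration):

# Each subdirectory of /tmp/train becomes one class label during training.
hdfs dfs -mkdir -p /tmp/train
hdfs dfs -put ham  /tmp/train/ham
hdfs dfs -put spam /tmp/train/spam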
Below is the model-training script. It is adapted from mahout-distribution-0.10.0/examples/bin/classify-20newsgroups.sh.
#!/bin/sh

# Convert the raw mail directories on HDFS into SequenceFiles.
mahout seqdirectory -i /tmp/train -o /tmp/train-seq -ow || exit 1

# Remove the vectors of a previous run; on the first run this only reports
# that the directory does not exist yet, which is harmless.
hadoop fs -rmr /tmp/train-vectors

# Turn the SequenceFiles into TF-IDF vectors.
mahout seq2sparse -i /tmp/train-seq -o /tmp/train-vectors -lnorm -nv -wt tfidf
if [ $? -ne 0 ]; then
    exit 1
fi

echo "=============================="
echo "seq2sparse finished"
echo "=============================="

# Hold out 40% of the vectors for testing and train a complementary naive
# Bayes model on the remaining 60%.
mahout split -i /tmp/train-vectors/tfidf-vectors \
    --trainingOutput /tmp/antispam-train-vectors \
    --testOutput /tmp/antispam-test-vectors \
    --randomSelectionPct 40 --overwrite --sequenceFiles -xm sequential && \
mahout trainnb -i /tmp/antispam-train-vectors -o /tmp/model -li /tmp/labelindex -ow -c

echo "=============================="
echo "trainnb finished"
echo "=============================="
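The script above only trains the model; it does not score the 40% of vectors held out by the split step. Following the same classify-20newsgroups.sh example, they can be evaluated with testnb. A minimal sketch (the output path /tmp/antispam-testing is just an illustration):

# Score the held-out vectors against the trained model; prints the accuracy
# and a confusion matrix for the ham/spam labels.
mahout testnb -i /tmp/antispam-test-vectors \
    -m /tmp/model -l /tmp/labelindex \
    -ow -o /tmp/antispam-testing -c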
III. Building the Mail-Filtering Web Server
1. Building the classifier
package com.melot.antispam;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.vectorizer.TFIDF;

/**
 * Classifies text using a naive Bayes model trained in Mahout.
 *
 * @author Vicente Ruben Del Pino Ruiz <<[email protected]>>
 */
public class Classifier {

    private AbstractVectorClassifier classifier;
    private NaiveBayesModel naiveBayesModel;
    private Map<String, Integer> dictionary;
    private Map<Integer, Long> documentFrequency;
    private int documentCount;
    private Map<Integer, String> labelIndex;

    public final static String DICTIONARY_PATH_CONF = "dictionaryPath";
    public final static String DOCUMENT_FREQUENCY_PATH_CONF = "documentFrequencyPath";
    public final static String DOCUMENT_PATH_CONF = "documentPath";
    public final static String MODEL_PATH_CONF = "modelPath";
    public final static String LABEL_INDEX = "labelIndex";

    /**
     * Loads the dictionary, the document frequencies, the document count, the
     * label index and the naive Bayes model from the paths given in the
     * configuration.
     *
     * @param configuration cluster configuration carrying the model paths
     */
    public Classifier(Configuration configuration) throws IOException {
        String modelPath = configuration.getStrings(MODEL_PATH_CONF)[0];
        String dictionaryPath = configuration.getStrings(DICTIONARY_PATH_CONF)[0];
        String documentFrequencyPath = configuration.getStrings(DOCUMENT_FREQUENCY_PATH_CONF)[0];
        String documentPath = configuration.getStrings(DOCUMENT_PATH_CONF)[0];
        String labelIndexPath = configuration.getStrings(LABEL_INDEX)[0];

        dictionary = readDictionnary(configuration, new Path(dictionaryPath));
        documentFrequency = readDocumentFrequency(configuration, new Path(documentFrequencyPath));
        documentCount = readDocumentCount(configuration, new Path(documentPath));
        labelIndex = readLabelIndex(configuration, new Path(labelIndexPath));

        naiveBayesModel = NaiveBayesModel.materialize(new Path(modelPath), configuration);
        classifier = new StandardNaiveBayesClassifier(naiveBayesModel);
    }

    /**
     * Returns the id of the category with the highest score in the given
     * score vector.
     */
    private int getBestCategory(Vector result) {
        double bestScore = -Double.MAX_VALUE;
        int bestCategoryId = -1;
        for (Element element : result.all()) {
            int categoryId = element.index();
            double score = element.get();
            if (score > bestScore) {
                bestScore = score;
                bestCategoryId = categoryId;
            }
        }
        return bestCategoryId;
    }

    /**
     * Builds a TF-IDF vector for the given word counts.
     *
     * @param words         word to occurrence count in the text to classify
     * @param wordCount     total number of known words in the text
     * @param documentCount number of documents in the training set
     */
    private Vector generateTFIDFVector(HashMap<String, Integer> words, int wordCount, int documentCount) {
        Vector vector = new RandomAccessSparseVector(10000);
        TFIDF tfidf = new TFIDF();
        for (Entry<String, Integer> entry : words.entrySet()) {
            String word = entry.getKey();
            int count = entry.getValue();
            Integer wordId = dictionary.get(word);
            Long freq = documentFrequency.get(wordId);
            double tfIdfValue = tfidf.calculate(count, freq.intValue(), wordCount, documentCount);
            vector.setQuick(wordId, tfIdfValue);
        }
        return vector;
    }

    /**
     * Main entry point of the class. Classifies a text using naive Bayes and
     * the model created by Mahout.
     *
     * @param text text to classify
     * @return the label of the best category (e.g. "spam" or "ham")
     */
    public String classify(String text) throws IOException {
        HashMap<String, Integer> words = new HashMap<String, Integer>();

        // Count the occurrences of each word; only words that appear in the
        // training dictionary are kept.
        String delims = "[ ]+";
        String[] ts = text.split(delims);
        int wordCount = 0;
        for (int i = 0; i < ts.length; i++) {
            String word = ts[i];
            Integer wordId = dictionary.get(word);
            if (wordId != null) {
                if (!words.containsKey(word)) {
                    words.put(word, 1);
                } else {
                    words.put(word, words.get(word) + 1);
                }
                wordCount++;
            }
        }

        // Build a TF-IDF vector for the text and classify it with the model.
        Vector vector = generateTFIDFVector(words, wordCount, documentCount);
        Vector result = classifier.classifyFull(vector);

        int bestCategoryID = getBestCategory(result);
        return labelIndex.get(bestCategoryID);
    }

    /**
     * Reads the dictionary (term to term id) into memory. With big training
     * sets this consumes a large amount of memory.
     */
    private Map<String, Integer> readDictionnary(Configuration conf, Path dictionnaryPath) {
        Map<String, Integer> dictionnary = new HashMap<String, Integer>();
        for (Pair<Text, IntWritable> pair : new SequenceFileIterable<Text, IntWritable>(
                dictionnaryPath, true, conf)) {
            dictionnary.put(pair.getFirst().toString(), pair.getSecond().get());
        }
        return dictionnary;
    }

    /**
     * Reads the document frequencies (term id to document frequency) into
     * memory. With big training sets this consumes a large amount of memory.
     */
    private Map<Integer, Long> readDocumentFrequency(Configuration conf, Path documentFrequencyPath) {
        Map<Integer, Long> documentFrequency = new HashMap<Integer, Long>();
        for (Pair<IntWritable, LongWritable> pair : new SequenceFileIterable<IntWritable, LongWritable>(
                documentFrequencyPath, true, conf)) {
            documentFrequency.put(pair.getFirst().get(), pair.getSecond().get());
        }
        return documentFrequency;
    }

    /**
     * Reads the label index (label id to label name) into memory.
     */
    private Map<Integer, String> readLabelIndex(Configuration conf, Path labelIndexPath) {
        Map<Integer, String> labelIndex = new HashMap<Integer, String>();
        for (Pair<Text, IntWritable> pair : new SequenceFileIterable<Text, IntWritable>(
                labelIndexPath, true, conf)) {
            labelIndex.put(pair.getSecond().get(), pair.getFirst().toString());
        }
        return labelIndex;
    }

    /**
     * Counts the number of documents in the training set by recursively
     * counting the files under the given path.
     */
    private int readDocumentCount(Configuration conf, Path documentPath) throws IOException {
        FileSystem fs = FileSystem.get(conf);
        FileStatus[] files = fs.listStatus(documentPath);
        int documentCount = 0;
        for (FileStatus file : files) {
            if (file.isDirectory()) {
                documentCount += readDocumentCount(conf, file.getPath());
            } else {
                documentCount += 1;
            }
        }
        return documentCount;
    }
}
2. Creating a simple HttpServlet
package com.melot.antispam;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.hadoop.conf.Configuration;

@SuppressWarnings("serial")
public class AntiSpamServlet extends HttpServlet {

    private Classifier classifier;

    @Override
    public void init() {
        // Tell the classifier where to find the model and its supporting
        // files on HDFS; the concrete paths come from the Config helper
        // class (see the repository linked below).
        Configuration conf = new Configuration();
        conf.set("fs.defaultFS", Config.hdfsURI);
        conf.set(Classifier.DICTIONARY_PATH_CONF, Config.dictionaryPath);
        conf.set(Classifier.DOCUMENT_FREQUENCY_PATH_CONF, Config.documentFrequencyPath);
        conf.set(Classifier.DOCUMENT_PATH_CONF, Config.documentPath);
        conf.set(Classifier.MODEL_PATH_CONF, Config.modelPath);
        conf.set(Classifier.LABEL_INDEX, Config.labelIndexPath);
        try {
            classifier = new Classifier(conf);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp)
            throws ServletException, IOException {
        String email = readAll(req.getReader());
        long t0 = System.currentTimeMillis();
        String category = classifier.classify(email);
        long t1 = System.currentTimeMillis();
        resp.getWriter().print(
                String.format("{\"category\":\"%s\", \"time\": %d}", category, t1 - t0));
    }

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp)
            throws ServletException, IOException {
        String email = readAll(req.getReader());
        long t0 = System.currentTimeMillis();
        String category = classifier.classify(email);
        long t1 = System.currentTimeMillis();
        resp.getWriter().print(
                String.format("{\"category\":\"%s\", \"time\": %d}", category, t1 - t0));
    }

    private String readAll(Reader reader) throws IOException {
        BufferedReader bufReader = new BufferedReader(reader);
        String line;
        StringBuilder sb = new StringBuilder();
        while ((line = bufReader.readLine()) != null) {
            sb.append(line).append(" ");
        }
        return sb.toString();
    }
}
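Once the servlet is deployed in a servlet container, classification is a single HTTP call. A minimal sketch, assuming the application is mapped to /antispam on localhost:8080 (both the port and the path are assumptions, not fixed by the code above):

# POST the raw text of an email; the servlet answers with the predicted
# category and the classification time in milliseconds, in the form
# {"category":"<ham or spam>", "time": <milliseconds>}
curl -X POST --data-binary @sample-email.txt http://localhost:8080/antispam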
The complete code is available in my Git repository on OSChina: https://git.oschina.net/moriarty279/antispam.git
IV. References
A spam-filter example for an older version of Mahout: https://emmaespina.wordpress.com/2011/04/26/ham-spam-and-elephants-or-how-to-build-a-spam-filter-server-with-mahout/
The main classifier code comes from https://github.com/MovieTrender/TwitterClassifier.git, but the way it obtains documentCount there is problematic; the readDocumentCount method above counts the files in the training directory instead.