使用weka进行文本分类的例子

import java.io.BufferedReader;

import java.io.File;

import java.io.FileInputStream;

import java.io.FileNotFoundException;

import java.io.FileOutputStream;

import java.io.FileReader;

import java.io.ObjectInputStream;

import java.io.ObjectOutputStream;

import java.io.Serializable;

import weka.classifiers.Classifier;

import weka.classifiers.bayes.NaiveBayes;

import weka.classifiers.trees.J48;

import weka.core.Attribute;

import weka.core.FastVector;

import weka.core.Instance;

import weka.core.Instances;

import weka.core.Utils;

import weka.filters.Filter;

import weka.filters.unsupervised.attribute.StringToWordVector;



public class MessageClassify implements Serializable 

{

	private Instances instances = null;	

	private StringToWordVector filter = new StringToWordVector();

	private Classifier classifier = new NaiveBayes();

	

	/**

	 * 构造分类器,主要及时对数据格式,类标,类别数目等进行说明

	 */

	public MessageClassify() throws Exception

	{

		String nameOfDataset = "MessageClassification";		

		FastVector attributes = new FastVector(2);

		attributes.addElement(new Attribute("Message", (FastVector) null));

		FastVector classValues = new FastVector(2);//类标向量,共有两类

		classValues.addElement("alt.atheism");

		classValues.addElement("comp.graphics");

		attributes.addElement(new Attribute("Class", classValues));				

		instances = new Instances(nameOfDataset, attributes, 100);//可以把instance认为是行,attribute认为是列

		instances.setClassIndex(instances.numAttributes() - 1);//类表在instance中的那列

	}



	/**

	 * 添加数据到训练集中

	 */

	public void updateData(String message, String classValue) throws Exception

	{

		Instance instance = makeInstance(message, instances);

		instance.setClassValue(classValue);

		instances.add(instance);	

	}

	

	/**

	 * 文本分类要特别一点,因为在使用StringToWordVector对象计算文本中词项(attribute)权重的时候需要用到全局变量,比如DF,所以这里需要批量处理

	 * 在weka中要注意有些机器学习算法是批处理有些不是

	 */

	public void finishBatch() throws Exception

	{

		filter.setInputFormat(instances);

		Instances filteredData = Filter.useFilter(instances, filter);//这才真正产生符合weka算法输入格式的数据集

		classifier.buildClassifier(filteredData);//真正的训练分类器

	}

	

	/**

	 * 分类过程

	 */

	public void classifyMessage(String message) throws Exception

	{			

		filter.input(makeInstance(message, instances.stringFreeStructure()));

		Instance filteredInstance = filter.output();//必须使用原来的filter

		

		double predicted = classifier.classifyInstance(filteredInstance);//(int)predicted是类标索引

		System.out.println("Message classified as : "

				+ instances.classAttribute().value((int) predicted));

	}



	

	private Instance makeInstance(String text, Instances data)

	{		

		Instance instance = new Instance(2);		

		Attribute messageAtt = data.attribute("Message");		

		instance.setValue(messageAtt, messageAtt.addStringValue(text));		

		instance.setDataset(data);

		return instance;

	}



	public static String getStringFromFile(File file)

	{

		StringBuilder sb=new StringBuilder();

		try

		{

			BufferedReader br=new BufferedReader(new FileReader(file));

			String line;

			while(true)

			{				

				if((line=br.readLine())==null) break;

				sb.append(line.trim());

			}		

			br.close();

		} catch (Exception e){}

		return sb.toString();

	}

	

	

	static String modelname="weka.message";	

	public static void main(String[] options)

	{

		try

		{

			MessageClassify messageCl=null;

			if(new File(modelname).exists())

				messageCl=loadModel(modelname);

			else

			{

				messageCl=trainModel();

				try

				{

					ObjectOutputStream modelOutObjectFile = new ObjectOutputStream(new FileOutputStream(modelname));

					modelOutObjectFile.writeObject(messageCl);

					modelOutObjectFile.close();

				} catch (Exception e){}

			}

			String testPath="E:/datasets/20_newsgroups_two_class/test/alt.atheism/54262";

			messageCl.classifyMessage(getStringFromFile(new File(testPath)));		

		} 

		catch (Exception e){}

	}



	/**

	 * 训练分类器

	 */

	private static MessageClassify trainModel()

	{

		MessageClassify mc=null;

		try

		{

			mc = new MessageClassify();			

			String basePath="E:/datasets/20_newsgroups_two_class/train";

			File base=new File(basePath);

			for(File dir : base.listFiles())

			{

				for(File file : dir.listFiles())

				{

					String message=getStringFromFile(file);

					String classValue=dir.getName();

					mc.updateData(message, classValue);//添加一条训练样本,classvalue就是类标

				}				

			}

			mc.finishBatch();//训练过程

			

		} catch (Exception e){}

		return mc;

	}



	public static MessageClassify loadModel(String modelname)

	{

		MessageClassify mc=null;

		try

		{

			ObjectInputStream modelInObjectFile =new ObjectInputStream(new FileInputStream(modelname));

			mc = (MessageClassify) modelInObjectFile.readObject();

			modelInObjectFile.close();

		}

		catch (Exception e){}

		return mc;		

	}

}

你可能感兴趣的:(例子)