Java实现的朴素贝叶斯分类器

目前的算法只能处理决策结果只有两种取值的情况（即 true 或 false）。多分类问题以及数值类型的属性暂时还无法处理。

用到的一些基础数据结构可以参考上一篇关于ID3的代码。 

 

这里只贴出实现贝叶斯分类预测的部分：

package classifier;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import util.ArffUtil;


/**
 * NBC means Naive Bayes Classifier
 * @author wenjun_yang
 *
 */
public class NBCUtil {
	
	ArffUtil util = new ArffUtil();
	private List attributeList = null;
	private List dataList = null;
	private String decAttributeName = null;
	private int decAttributeIndex = -1;
	
	private Map> seperatedDataTable = null;
	public NBCUtil(String decAttributeName, List attributeList, List dataList) {
		this.attributeList = attributeList;
		this.dataList = dataList;
		this.decAttributeName = decAttributeName;
		
		this.decAttributeIndex = util.getValueIndex(decAttributeName, this.attributeList);
		this.seperatedDataTable = seperateDataList(dataList);
	}
	
	private Map> seperateDataList(List dataList) {
		Map> map = new HashMap>();
		
		for(String[] arr : dataList) {
			if(decAttributeIndex >= 0 && decAttributeIndex < arr.length) {
				String currentKey = arr[decAttributeIndex]; 
				if(map.containsKey(currentKey)) {
					List tempList = map.get(currentKey);
					tempList.add(arr);
					map.put(currentKey, tempList);
				} else {
					List tempList = new ArrayList();
					tempList.add(arr);
					map.put(currentKey , tempList);
				}
			}
		}
		
		return map;
	}
	
	public Boolean predict(Map predictData, String targetDecAttributeValue) {
		if(predictData.containsKey(decAttributeName)) predictData.remove(decAttributeName);
		
		List positiveDataTable = new ArrayList();
		if(seperatedDataTable.containsKey(targetDecAttributeValue)) {
			positiveDataTable = seperatedDataTable.get(targetDecAttributeValue);
		}
		
		double resultP = 1.;
		
		// Step 1: 逐个属性的比率进行计算
		// 即: 计算 P(Attr=Value|Y=true) / P(Attr=Value|Y=false) 的值
		for(String attrName : predictData.keySet()) {
			String attrValue = predictData.get(attrName);
			int attrIndex = util.getValueIndex(attrName, attributeList);
			int attrPositiveCount = 0;
			int attrNegativeCount = 0;
			
			for(String[] arr : dataList) {
				if(arr[attrIndex].equals(attrValue)) {
					if(arr[decAttributeIndex].equals(targetDecAttributeValue)) {
						attrPositiveCount++;
					} else {
						attrNegativeCount++;
					}
				}
			}
			double temp =  (attrPositiveCount / (double)positiveDataTable.size() ) /
							(attrNegativeCount / (double)(dataList.size() - positiveDataTable.size()));
			resultP *= temp;
		}
		// 最后计算 P(Y=true) / P(Y=false)
		resultP *= positiveDataTable.size() / (double)(dataList.size() - positiveDataTable.size());
		System.out.println(resultP);
		if(resultP > 1) {
			return true;
		} else {
			return false;
		}
	}
}

 

 

完整的项目也已上传，可以直接使用。

数据集来自 Weka。

你可能感兴趣的:(代码仓库-Java,Java,数据挖掘,数据挖掘,Java,聚类)