数据挖掘之朴素贝叶斯算法的实现

这是我数据挖掘课的作业,也就是实现一个朴素贝叶斯算法。所用的训练数据集为加州大学计算机系提供的breast-cancer.data和segment.data。我得出的朴素贝叶斯算法对于离散型属性的预测准确度为0.72,对于连续型属性的预测准确度为0.79。

代码如下:

/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */
package auxiliary;

import java.beans.FeatureDescriptor;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;


/**
 *
 * @author daq
 */

class Store{//用于标识离散属性的P(Xi|Cj)的类
	int attr;//哪个属性
	double attrValue;//属性对应的值
	double lable;//与之对应的标签的值
	
	@Override
	public int hashCode() {//重写的hashCode方法
		final int prime = 31;
		int result = 1;
		result = prime * result + attr;
		long temp;
		temp = Double.doubleToLongBits(attrValue);
		result = prime * result + (int) (temp ^ (temp >>> 32));
		temp = Double.doubleToLongBits(lable);
		result = prime * result + (int) (temp ^ (temp >>> 32));
		return result;
	}
	
	@Override
	public boolean equals(Object obj) {//重写的equals方法
		if (this == obj)
			return true;
		if (obj == null)
			return false;
		if (getClass() != obj.getClass())
			return false;
		Store other = (Store) obj;
		if (attr != other.attr)
			return false;
		if (Double.doubleToLongBits(attrValue) != Double
				.doubleToLongBits(other.attrValue))
			return false;
		if (Double.doubleToLongBits(lable) != Double
				.doubleToLongBits(other.lable))
			return false;
		return true;
	}
}

class Store2{//用于标识连续属性的P(Xi|Cj)的类
	int attr;//哪个属性
	double label;//标签的值
	
	@Override
	public int hashCode() {//重写的hashCode方法
		final int prime = 31;
		int result = 1;
		result = prime * result + attr;
		long temp;
		temp = Double.doubleToLongBits(label);
		result = prime * result + (int) (temp ^ (temp >>> 32));
		return result;
	}
	
	@Override
	public boolean equals(Object obj) {//重写的equals方法
		if (this == obj)
			return true;
		if (obj == null)
			return false;
		if (getClass() != obj.getClass())
			return false;
		Store2 other = (Store2) obj;
		if (attr != other.attr)
			return false;
		if (Double.doubleToLongBits(label) != Double
				.doubleToLongBits(other.label))
			return false;
		return true;
	}
}



public class NaiveBayes extends Classifier {

	boolean []myIsCategory;
	double [][]myFeatures;
	double []myLabels;
	
	ArrayList labelKinds=new ArrayList();//label的数组的种类的数组
	HashMap labelKindsProp=new HashMap();//每种label的概率
	HashMap attrLabelsProp=new HashMap();
	HashMap> valueKinds=new HashMap>();
	HashMap averageAttrs=new HashMap();
	HashMapstandDev=new HashMap();//attr,labelKind
	
    public NaiveBayes() {
    	
    }

    public void setLabelKinds(double[] labels){//计算标签种类并存储各个不同标签值
    	for(int i=0;i values=valueKinds.get(i);
	    			if(values==null)
	    				values=new ArrayList();
	    			if(!values.contains(features[j][i])){
	    				values.add(features[j][i]);
	    				valueKinds.put(i,values);
	    			}
	    		}
    		}
    	}
    }
    
    public void setAttrLabelsProp(boolean[] isCategory, double[][] features){//对于离散的属性,计算不同的值占全部元祖的比例
    	for(int i=0;i values=valueKinds.get(j);
	    			int num[]=new int[values.size()];
	    			for(int k=0;kresMax){    			
    			resMax=res;
    			resMaxIndex=label;
    		}
    	}
    	return resMaxIndex;
    }
}

这段代码中,train函数用来训练数据,predict函数预测数据。train的参数isCategory数组是存储各个属性是连续的还是离散的,连续的话,值为0,离散的话,值为1。features[][]数组存放训练数据,label数组用来存储各个元组的标签值。

你可能感兴趣的:(数据挖掘之朴素贝叶斯算法的实现)