优点:在数据较少的情况下,仍然有效,可以处理多分类问题
缺点:对于输入数据的准备方式比较敏感
适用数据类型:标称型数据
###主要思想
p1(x,y)表示数据点(x,y)属于类别1的概率;
p2(x,y)表示数据点(x,y)属于类别2的概率;
if:
p1>p2 属于1类;
else
属于2类
//C,类别集合,D,用于训练的文本文件集合
TrainMultiNomialNB(C,D) {
// 单词出现多次,只算一个
V←ExtractVocabulary(D)
// 单词可重复计算
N←CountTokens(D)
for each c∈C
// 计算类别c下的单词总数
// N和Nc的计算方法和Introduction to Information Retrieval上的不同,个人认为
//该书是错误的,先验概率和类条件概率的计算方法应当保持一致
Nc←CountTokensInClass(D,c)
prior[c]←Nc/N
// 将类别c下的文档连接成一个大字符串
textc←ConcatenateTextOfAllDocsInClass(D,c)
for each t∈V
// 计算类c下单词t的出现次数
Tct←CountTokensOfTerm(textc,t)
for each t∈V
//计算P(t|c)
condprob[t][c]←(Tct+1)/(Nc+|V|)   // 拉普拉斯平滑:分子加1,分母加上词典大小|V|
return V,prior,condprob
}
ApplyMultiNomialNB(C,V,prior,condprob,d) {
// 将文档d中的单词抽取出来,允许重复,如果单词是全新的,在全局单词表V中都
// 没出现过,则忽略掉
W←ExtractTokensFromDoc(V,d)
for each c∈C
score[c]←prior[c]
for each t∈W
if t∈V
score[c] *= condprob[t][c]
return argmax_{c∈C} score[c]   // 返回得分最高的类别,而不是得分本身
}
/************************************************************************/
//C,类别集合,D,用于训练的文本文件集合
TrainBernoulliNB(C, D) {
// 单词出现多次,只算一个
V←ExtractVocabulary(D)
// 计算文件总数
N←CountDocs(D)
for each c∈C
// 计算类别c下的文件总数
Nc←CountDocsInClass(D,c)
prior[c]←Nc/N
for each t∈V
// 计算类c下包含单词t的文件数
Nct←CountDocsInClassContainingTerm(D,c,t)
//计算P(t|c)
condprob[t][c]←(Nct+1)/(Nc+2)   // 分母是类别c下的文件总数Nc加2,而不是Nct加2
return V,prior,condprob
}
ApplyBernoulliNB(C,V,prior,condprob,d) {
// 将文档d中单词表抽取出来,如果单词是全新的,在全局单词表V中都没出现过,
// 则舍弃
Vd←ExtractTermsFromDoc(V,d)
for each c∈C
score[c]←prior[c]
for each t∈V
if t∈Vd
score[c] *= condprob[t][c]
else
score[c] *= (1-condprob[t][c])
return argmax_{c∈C} score[c]   // 返回得分最高的类别,而不是得分本身
}
/**
*
*/
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import com.utils.MathM;
/**
* @author Home
*
*/
public class NBMain {
    /** Path of the CSV training file: one document per line, comma-separated words, last field is the label (0/1). */
    final static String dataPath = "data.csv";
    static List<String[]> dataList = new ArrayList<String[]>();   // training docs as raw word arrays (label in last slot)
    static List<float[]> vectorList = new ArrayList<float[]>();   // one-hot word vectors of the training docs
    static List<String> vocabList = new ArrayList<String>();      // vocabulary: unique words, sorted
    static float[] trainCategory;   // per-document labels, e.g. [0,1,0,1,0,1]
    static int numTrainDocs = 0;    // number of training documents
    static int numwords = 0;        // vocabulary size
    static MathM mm = new MathM();

    public static void main(String[] args) throws IOException {
        // Load the data set, train the model, then classify two sample documents.
        vocabList = loadDataSet(dataPath);
        Model model = trainBayes();
        System.out.println(Arrays.toString(model.p0Vect));
        System.out.println(Arrays.toString(model.p1Vect));
        String[] test1 = {"love","my","dalmation"};
        String[] test2 = {"stupid","garbage"};
        System.out.println(classifyNB(setofWords2Vec(vocabList, test1, test1.length), model));
        System.out.println(classifyNB(setofWords2Vec(vocabList, test2, test2.length), model));
    }

    /**
     * Classifies a one-hot document vector by comparing the two class log-posteriors.
     * The model vectors already hold log-probabilities, so the dot product sums logs
     * (avoids float underflow from multiplying many small probabilities).
     *
     * @param vec2Classify one-hot vector over the vocabulary
     * @param model        trained parameters (log P(w|c) vectors and prior P(c=1))
     * @return 1 if class 1 is more likely, else 0
     */
    public static int classifyNB(float[] vec2Classify, Model model) {
        double p1 = mm.multiply(vec2Classify, model.p1Vect) + Math.log(model.pAbusive);
        double p0 = mm.multiply(vec2Classify, model.p0Vect) + Math.log(1 - model.pAbusive);
        return p1 > p0 ? 1 : 0;
    }

    /**
     * Reads the CSV training file, fills the static training structures
     * (dataList, trainCategory, vectorList, numTrainDocs, numwords) and
     * returns the sorted vocabulary.
     *
     * @param dataPath path to the CSV file
     * @return sorted vocabulary of all words seen in training documents
     * @throws IOException if reading the file fails
     */
    public static List<String> loadDataSet(String dataPath) throws IOException {
        // try-with-resources: the original code could leak the reader when
        // readLine() threw, because close() was not in a finally block.
        try (BufferedReader br = new BufferedReader(new FileReader(new File(dataPath)))) {
            String line;
            while ((line = br.readLine()) != null) {
                dataList.add(line.split(","));
            }
        } catch (FileNotFoundException e) {
            // Keep the original best-effort behavior: log and continue with an empty data set.
            e.printStackTrace();
        }
        numTrainDocs = dataList.size();
        List<String> vocabList = new ArrayList<String>();
        trainCategory = new float[dataList.size()];
        int j = 0;
        for (String[] str : dataList) {
            // Last field is the label, everything before it is a word.
            for (int i = 0; i < str.length - 1; i++)
                if (!vocabList.contains(str[i]))
                    vocabList.add(str[i]);
            trainCategory[j] = Integer.parseInt(str[str.length - 1]);
            j++;
        }
        Collections.sort(vocabList);
        for (String[] str : dataList) {
            float[] temp = setofWords2Vec(vocabList, str, str.length - 1);
            vectorList.add(temp);
        }
        numwords = vocabList.size();
        return vocabList;
    }

    /**
     * Converts the first n words of a document into a one-hot vector over the
     * vocabulary (set-of-words model: presence only, not counts).
     * Words not present in the vocabulary are ignored — the original code
     * indexed temp[-1] for unknown words, which threw
     * ArrayIndexOutOfBoundsException.
     *
     * @param vocabList  the global vocabulary
     * @param postingDoc the document's words
     * @param n          how many leading entries of postingDoc are words
     * @return a float vector with 1.0 at each vocabulary word present in the doc
     */
    public static float[] setofWords2Vec(List<String> vocabList, String[] postingDoc, int n) {
        float[] temp = new float[vocabList.size()];
        for (int i = 0; i < n; i++) {
            int index = vocabList.indexOf(postingDoc[i]);
            if (index >= 0)          // skip words missing from the vocabulary
                temp[index] = 1.0f;
        }
        return temp;
    }

    /*
     * 1. Compute the prior P(c=1) of abusive documents; P(c=0) = 1 - P(c=1).
     * 2. Compute the class-conditional word probabilities P(wi|c1) and P(wi|c0),
     *    with Laplace smoothing and stored as logs.
     *
     * Reads the static fields trainCategory, vectorList, numTrainDocs, numwords.
     */
    public static Model trainBayes() {
        float pAbusive = (float) (mm.sum(trainCategory) / numTrainDocs);
        float[] p0Num = new float[numwords];
        float[] p1Num = new float[numwords];
        // Laplace smoothing: start counts at 1 and denominators at 2 so that
        // unseen words do not zero out the product of probabilities.
        Arrays.fill(p0Num, 1);
        Arrays.fill(p1Num, 1);
        float p0Denom = 2.0f;
        float p1Denom = 2.0f;
        for (int i = 0; i < numTrainDocs; i++) {
            float[] temp = vectorList.get(i);
            if (trainCategory[i] == 1) {
                p1Num = mm.dot(p1Num, temp);   // element-wise add: accumulate word counts
                p1Denom += mm.sum(temp);
            } else {
                p0Num = mm.dot(p0Num, temp);
                p0Denom += mm.sum(temp);
            }
        }
        // Store log-probabilities to avoid underflow: ln(a*b) = ln(a) + ln(b).
        float[] p1Vect = mm.fVect(p1Num, p1Denom);
        float[] p0Vect = mm.fVect(p0Num, p0Denom);
        return new Model(p0Vect, p1Vect, pAbusive);
    }
}
/**
*模型构造方法
*/
package com.loadData;
/**
* @author Home
*
*/
/**
 * Immutable-ish holder for the trained Naive Bayes parameters:
 * the per-class log-probability vectors and the prior of class 1.
 */
public class Model {
    float[] p0Vect;   // log P(w|c=0) for every vocabulary word
    float[] p1Vect;   // log P(w|c=1) for every vocabulary word
    float pAbusive;   // prior probability P(c=1)

    /**
     * @param p0Vect   log-probability vector for class 0
     * @param p1Vect   log-probability vector for class 1
     * @param pAbusive prior probability of class 1
     */
    public Model(float[] p0Vect, float[] p1Vect, float pAbusive) {
        this.p0Vect = p0Vect;
        this.p1Vect = p1Vect;
        this.pAbusive = pAbusive;
    }

    public float[] getP0Vect() {
        return p0Vect;
    }

    public void setP0Vect(float[] p0Vect) {
        this.p0Vect = p0Vect;
    }

    public float[] getP1Vect() {
        return p1Vect;
    }

    public void setP1Vect(float[] p1Vect) {
        this.p1Vect = p1Vect;
    }

    /** Properly named accessor for the class-1 prior. */
    public float getPAbusive() {
        return pAbusive;
    }

    /**
     * Legacy accessor for the class-1 prior.
     * @deprecated misnamed; use {@link #getPAbusive()} instead.
     */
    @Deprecated
    public float getmodel() {
        return pAbusive;
    }

    public void setpAbusive(float pAbusive) {
        this.pAbusive = pAbusive;
    }
}
package com.utils;
/*
*自定义方法
* */
/*
 * Small float-vector helpers used by the Naive Bayes trainer.
 */
public class MathM {

    /** Returns the sum of all entries of R (0 for an empty array). */
    public float sum(float[] R) {
        float total = 0f;
        for (int idx = 0; idx < R.length; idx++) {
            total += R[idx];
        }
        return total;
    }

    /**
     * Element-wise addition of A and B into a new array.
     * NOTE: despite the name, this is NOT a dot product — the trainer uses it
     * to accumulate per-word counts (assumes B is at least as long as A).
     */
    public float[] dot(float A[], float B[]) {
        int n = A.length;
        float[] result = new float[n];
        int idx = 0;
        while (idx < n) {
            result[idx] = A[idx] + B[idx];
            idx++;
        }
        return result;
    }

    /** Returns log(A[i] / pDenom) for every entry, narrowed to float. */
    public float[] fVect(float[] A, float pDenom) {
        float[] logs = new float[A.length];
        for (int idx = 0; idx < logs.length; idx++) {
            logs[idx] = (float) Math.log(A[idx] / pDenom);
        }
        return logs;
    }

    /** Inner (dot) product of A and B, accumulated in double precision. */
    public double multiply(float[] A, float[] B) {
        double acc = 0.0;
        for (int idx = 0; idx < A.length; idx++) {
            acc += A[idx] * B[idx];
        }
        return acc;
    }
}
data.csv
my,dog,has,flea,problem,help,please,0
maybe,not,take,him,to,dog,park,stupid,1
my,dalmation,is,so,cute,I,love,him,0
stop,posting,stupid,worthless,garbage,1
me,licks,ate,my,steak,how,to,stop,him,0
quit,buying,worthless,dog,food,stupid,1