朴素贝叶斯(NB classification)实现文档分类

编程环境:

anaconda + python3.7+Spyder+python3.6
Win10
完整代码及数据已经更新至GitHub,欢迎fork~GitHub链接


声明:创作不易,未经授权不得复制转载
statement:No reprinting without authorization


内容概述:

实现朴素贝叶斯分类器,测试其在20 Newsgroups数据集上的效果。

一、机器学习数据挖掘算法之朴素贝叶斯分类器原理介绍:

1.核心公式:

image.png

2.选取朴素贝叶斯的模型:

       一般有三种模型可供选取:伯努利模型、多项式模型、混合模型,伯努利将重复词只视为出现一次,会丢失词频信息,多项式统计和判断时都关注重复次数,混合模型结合前两种在测试时不考虑词频,但在统计时考虑重复次数;
结合效果发现多项式会有更好表现,于是选择了多项式模型

bayes_model

3.平滑技术的选择:

由于使用了多项式的模型,于是选择了更精确的平滑技术,计算的公式如下eg:

image.png

对应到本实验中即为:
P(“term”|class1)=class1中“term”出现的总次数 + 1)/class1中总词数+词表的总词数
利用朴素贝叶斯的条件独立性假设:
image.png

得到如下公式推导:


image.png

由于本实验所选样本和测试数据集分布都比较均匀,P(Vi)可以近似相等,所以P(Vi)可以约掉不用计算。

4.Tricks的应用:

       使用取对数来代替很多个概率相乘,原本相乘的这些概率P(term1|ci)P(term2|ci)P(term3|ci),他们的值大多都非常小,所以程序会下溢出或者结果被四舍五入后得到0,得到错误结果,使用对数可以完美解决。
       而且取对数后因为P(term|ci)的值都在0-1之间,而在这区间内对数函数变化很快,能够突出反映不同单词的概率对总体判断的影响,抵消部分当大量词在统计样本类别中未出现的情况,经过实验发现可以相当程度上增加准确率,如下图:
(1)对对数进行加1,使每个P(term|ci)都大于1,处于对数比较平缓的一段:


image.png

image.png

(2)对对数进行不加1,使每个P(term|ci)都在0-1之间,处于对数变化较快的一段:


image.png
image.png

       对比(1)和(2)可以发现处于对数变化较快时的准确率有较大提升(其中对于comp.os.ms-windows.misc类的数据发现无论怎么调整参数都无法提高准确率,怀疑该类数据的文档间没什么关联性)。

二、数据集介绍与划分:

       数据集为新闻文档数据集,有若干个类别,每个类别有数百个文档,数据集链接:#20news-18828.tar.gz:http://qwone.com/~jason/20Newsgroups/20news-18828.tar.gz 20 Newsgroups; duplicates removed, only "From" and "Subject" headers (18828 documents)
       按2:1进行划分,用来统计(训练)的样本占三分之二,测试样本三分之一,由于全部数据集太多,个人电脑跑起来比较费时,于是选取5到6类来进行实验,如下所示:
类别名:(测试文档数/训练文档数)
alt.atheism:(260 / 539)
comp.graphics:(321 / 652)
......

image.png

image.png

文本示例:
image.png

五、具体python代码实现细节:

       使用doc_filenames={}来分别存储在统计时记录某一个类别的所有文档的读取路径,使用postings[term]=number_of_term来记录建立每个类别的词典,并记录对应单词在该类别中的出现次数,使用num_cx来记录每个类别的总次数;total_aRate表示总的在测试集上的准确率。

1、对于每一类统计用数据集遍历每个文档得到其postings[]:
for id in doc1_filenames:
        f = open(doc1_filenames[id],'r',encoding='utf-8',errors='ignore')
        document = f.read()
        f.close()
        terms = tokenize(document)
        num_c1+=len(terms)#类1总词数
        unique_terms = set(terms)
        for term in unique_terms:
            if term not in postings1:
                postings1[term] = (terms.count(term))
            else:
                postings1[term]=(postings1[term]+(terms.count(term)))
2、Tokenize使用和SVM同样的方法:
def tokenize(document):    
    document=document.lower()
    document=re.sub(r"\W|\d|_|\s{2,}"," ",document)
    terms=TextBlob(document).words.singularize()

    result=[]
    for word in terms:
        expected_str = Word(word)
        expected_str = expected_str.lemmatize("v")
        result.append(expected_str)
    return result 
3、对测试数据集上进行判断时,对文档进行同样的tokenize,而后计算分别属于每个类别的概率(取对数后),选择最大的作为该文本的类别。
for id in doc1_test:
        f = open(doc1_test[id],'r',encoding='utf-8',errors='ignore')
        document = f.read()
        f.close()
        terms = tokenize(document)
        p=[0,0,0,0,0]
        for term in terms:
            if term in postings1:
                p[0]+=math.log((postings1[term]+1)/(num_c1+nc1))
            else:
                p[0]+=math.log(1/(num_c1+nc1))
                
            if term in postings2:
                p[1]+=math.log((postings2[term]+1)/(num_c2+nc2))
            else:
                p[1]+=math.log(1/(num_c2+nc2))
                
            if term in postings3:
                p[2]+=math.log((postings3[term]+1)/(num_c3+nc3))
            else:
                p[2]+=math.log(1/(num_c3+nc3))
                
            if term in postings4:
                p[3]+=math.log((postings4[term]+1)/(num_c4+nc4))
            else:
                p[3]+=math.log(1/(num_c4+nc4))
            if term in postings5:
                p[4]+=math.log((postings5[term]+1)/(num_c5+nc5))
            else:
                p[4]+=math.log(1/(num_c5+nc5))
        
        if p[ss]==max(p):
            count1+=1
    print(ss+1,"类名:",Newspath1[74:])
    print("判对文档数:",count1,"总的文档数:",len(doc1_test))
    total_aRate = (total_aRate+count1/len(doc1_test))
    print("准确率为:",count1/len(doc1_test))

六、最终结果展示:

image.png

基本平均准确率能达到百分之90以上,达到预期实验效果。

你可能感兴趣的:(朴素贝叶斯(NB classification)实现文档分类)