朴素贝叶斯算法及Python的简单实现

朴素贝叶斯算法及Python的简单实现

贝叶斯算法起源于古典数学理论,是一种分类算法的总称。它以贝叶斯定理为基础,假设某待分类的样本满足某种概率分布,并且可以根据已观察到的样本数据对该样本进行概率计算,以得出最优的分类决策。通过计算已观察到的样本数据估计某待分类样本的先验概率,利用贝叶斯公式计算出其后验概率,即该样本属于某一类的概率,选择具有最大后验概率的类作为该样本所属的类。

先验概率是根据以往经验和分析得到的概率。可以用P(h)来表示假设h的先验概率,用P(D)表示将要观察的训练数据D的先验概率,用P(D|h)表示假设h成立的情况下观察到数据D的概率,关心的是P(h|D),即给定训练数据D时h成立的概率,称之为后验概率。

贝叶斯公式:P(h|D)=P(D|h)P(h)/P(D),从公式可以看出P(h|D)随着P(h)和P(D|h)的增长而增长,也可以看出P(h|D)随着P(D)的增加而减少。考虑候选假设集合H并在其中寻找给定数据D时可能性最大的假设h。这样的具有最大可能性的假设被称为极大后验(MAP)假设。确定MAP假设的方法是用贝叶斯公式计算每个候选假设的后验概率,可能会多个属性,且属性之间可能存在复杂的依赖关系,这就使得P(D|h)的计算十分困难,为了简化条件概率的求解难度,提出了一种条件独立假设,即假设训练数据D中,各属性之间相互独立。在贝叶斯算法基础上添上条件独立假设,我们就称之为朴素贝叶斯算法。

# -*- coding:utf-8 -*-
import numpy as np
__author__ = 'yangxin'
"""
贝叶斯公式
p(xy)=p(x|y)p(y)=p(y|x)p(x)
p(x|y)=p(y|x)p(x)/p(y)
"""
 
 
class SpeechJudgment(object):
 
    def load_data_set(self):
        # 单词列表
        posting_list = [
            ['my', 'dog', 'has', 'flea', 'problems', 'help', 'please'],
            ['maybe', 'not', 'take', 'him', 'to', 'dog', 'park', 'stupid'],
            ['my', 'dalmation', 'is', 'so', 'cute', 'I', 'love', 'him'],
            ['stop', 'posting', 'stupid', 'worthless', 'gar e'],
            ['mr', 'licks', 'ate', 'my', 'steak', 'how', 'to', 'stop', 'him'],
            ['quit', 'buying', 'worthless', 'dog', 'food', 'stupid']]
        # 属性类别列表 1 -> 侮辱性的文字, 0 -> not
        class_vec = [0, 1, 0, 1, 0, 1]
        return posting_list, class_vec
 
    def create_vocab_list(self, data_set):
        vocab_set = set()
        for item in data_set:
            vocab_set = vocab_set | set(item)
        # 不含重复元素的单词列表
        return list(vocab_set)
 
    def set_of_words2vec(self, vocab_list, input_set):
        result = [0] * len(vocab_list)
        for word in input_set:
            if word in vocab_list:
                # 如单词在输入文档出现过,则标记为1,否则为0
                result[vocab_list.index(word)] = 1
        return result
 
    def train_naive_bayes(self, train_mat, train_category):
        train_doc_num = len(train_mat)
        words_num = len(train_mat[0])
        pos_abusive = np.sum(train_category) / train_doc_num
        # 创建一个长度为words_num的都是1的列表
        p0num = np.ones(words_num)
        p1num = np.ones(words_num)
        p0num_all = 2.0
        p1num_all = 2.0
        for i in range(train_doc_num):
            if train_category[i] == 1:
                p1num += train_mat[i]
                p1num_all += np.sum(train_mat[i])
            else:
                p0num += train_mat[i]
                p0num_all += np.sum(train_mat[i])
        p1vec = np.log(p1num / p1num_all)
        p0vec = np.log(p0num / p0num_all)
        return p0vec, p1vec, pos_abusive
 
    def classify_naive_bayes(self, vec_to_classify, p0vec, p1vec, p_class1):
        p1 = np.sum(vec_to_classify * p1vec) + np.log(p_class1)
        p0 = np.sum(vec_to_classify * p0vec) + np.log(1 - p_class1)
        if p1 > p0:
            return 1
        else:
            return 0
 
    def bag_words_to_vec(self, vocab_list, input_set):
        result = [0] * len(vocab_list)
        for word in input_set:
            if word in vocab_list:
                result[vocab_list.index(word)] += 1
            else:
                print('the word: {} is not in my vocabulary'.format(word))
        return result
 
 
    def testing_naive_bayes(self):
        list_post, list_classes = self.load_data_set()
        vocab_list = self.create_vocab_list(list_post)
        train_mat = []
        for post_in in list_post:
            train_mat.append(
                self.set_of_words_to_vec(vocab_list, post_in)
            )
        p0v, p1v, p_abusive = self.train_naive_bayes(np.array(train_mat), np.array(list_classes))
        test_one = ['love', 'my', 'dalmation']
        test_one_doc = np.array(self.set_of_words2vec(vocab_list, test_one))
        print('the result is: {}'.format(self.classify_naive_bayes(test_one_doc, p0v, p1v, p_abusive)))
        test_two = ['stupid', 'garbage']
        test_two_doc = np.array(self.set_of_words2vec(vocab_list, test_two))
        print('the result is: {}'.format(self.classify_naive_bayes(test_two_doc, p0v, p1v, p_abusive)))

你可能感兴趣的:(算法)