西瓜书 课后习题7.3:朴素贝叶斯分类器(拉普拉斯修正)

import csv
from collections import Counter
from math import sqrt, pi, exp, log


def read_data(filename, encoding=None):
    '''
    Read watermelon dataset 3.0 from a CSV file.

    The first row is treated as a header and skipped. For every remaining
    non-blank row, column 0 is dropped (presumably a sample id - verify),
    columns 1..9 become one entry of X (so the label is kept as the last
    element of each X row) and column 9 is appended to Y.

    :param filename: path of the CSV file
    :param encoding: text encoding passed to open(); None keeps the platform
                     default (e.g. pass 'utf-8' or 'gbk' explicitly)
    :return: (X, Y) - list of attribute rows (strings) and list of labels
    '''
    X, Y = [], []
    # newline='' is required by the csv module so quoted fields and line
    # endings are handled by the reader, not by the text layer.
    with open(filename, newline='', encoding=encoding) as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row; tolerate an empty file
        for line in reader:
            if not line:  # csv.reader yields [] for blank lines
                continue
            X.append(line[1:10])
            Y.append(line[9])
    return X, Y


def _laplace_tables(X, Y, column, class_counts):
    '''Return Laplace-smoothed (positive, negative) P(value | class) dicts for one discrete column.'''
    # Seed both count tables with every value seen anywhere in the column.
    # p_discrete divides by class_count + len(dict), so this makes the
    # denominator |D_c| + N_i (N_i = number of possible values of the
    # attribute), and values never seen within a class get (0 + 1) / (|D_c| + N_i)
    # instead of being missing from the table.
    positive_counts = dict.fromkeys((row[column] for row in X), 0)
    negative_counts = dict(positive_counts)
    for row, label in zip(X, Y):
        if label == '是':
            positive_counts[row[column]] += 1
        else:
            negative_counts[row[column]] += 1
    return (p_discrete(positive_counts, class_counts['是']),
            p_discrete(negative_counts, class_counts['否']))


def _gaussian_params(X, Y, column):
    '''Return (mean_pos, var_pos, mean_neg, var_neg) of one numeric column split by class.'''
    positive, negative = [], []
    for row, label in zip(X, Y):
        if label == '是':
            positive.append(float(row[column]))
        else:
            negative.append(float(row[column]))
    mean_positive, var_positive = mean_var(positive)
    mean_negative, var_negative = mean_var(negative)
    return mean_positive, var_positive, mean_negative, var_negative


def LNB(X, Y, num_discrete=6, num_continuous=2):
    '''
    Train a naive Bayes classifier with Laplace correction ("lookup table" method).

    :param X: samples; columns 0..num_discrete-1 hold discrete attribute
              values and the next num_continuous columns hold numeric values
              (numeric strings are converted with float()). Extra trailing
              columns, such as the label kept by read_data, are ignored.
    :param Y: labels, one per sample; '是' is the positive class and every
              other label is counted on the negative side.
    :param num_discrete: number of leading discrete attributes (default 6)
    :param num_continuous: number of numeric attributes after them (default 2)
    :return: (p_class, unique_count_discrete_positive_p,
              unique_count_discrete_negative_p, means_positive, vars_positive,
              means_negative, vars_negative) where p_class maps each label to
              its smoothed prior, the two *_discrete_*_p lists hold one
              {value: P(value | class)} dict per discrete attribute, and the
              means/vars lists hold one entry per numeric attribute.
    :raises KeyError: if Y contains no '是' label or no '否' label.
    '''
    total_num = len(Y)
    unique_class_dict = count_list(Y)
    num_class = len(unique_class_dict)
    # Smoothed class prior: (|D_c| + 1) / (|D| + N), N = number of classes.
    p_class = {c: float(n + 1) / (total_num + num_class)
               for c, n in unique_class_dict.items()}

    # Conditional probability tables for the discrete attributes.
    unique_count_discrete_positive_p, unique_count_discrete_negative_p = [], []
    for i in range(num_discrete):
        positive_p, negative_p = _laplace_tables(X, Y, i, unique_class_dict)
        unique_count_discrete_positive_p.append(positive_p)
        unique_count_discrete_negative_p.append(negative_p)

    # Per-class mean and variance of the numeric attributes.
    means_positive, vars_positive = [], []
    means_negative, vars_negative = [], []
    for i in range(num_discrete, num_discrete + num_continuous):
        mean_pos, var_pos, mean_neg, var_neg = _gaussian_params(X, Y, i)
        means_positive.append(mean_pos)
        vars_positive.append(var_pos)
        means_negative.append(mean_neg)
        vars_negative.append(var_neg)

    return (p_class, unique_count_discrete_positive_p,
            unique_count_discrete_negative_p,
            means_positive, vars_positive, means_negative, vars_negative)


def count_list(data_list):
    '''
    Count how many times each distinct value occurs.

    :param data_list: values of one attribute (or the labels); items must be hashable
    :return: dict mapping each distinct value to its number of occurrences
    '''
    # Counter tallies everything in one pass instead of one list.count()
    # scan per distinct value.
    return dict(Counter(data_list))


def p_discrete(unique_count_dict, class_count):
    '''
    Turn per-value counts within one class into Laplace-smoothed probabilities.

    Each value a gets (count(a) + 1) / (class_count + len(unique_count_dict)).
    The number of entries in unique_count_dict plays the role of N_i, so for
    the textbook correction the dict must list every possible value of the
    attribute, including values with a count of 0.

    :param unique_count_dict: {attribute value: count within the class}
    :param class_count: number of samples in the class (e.g. 8 positive, 9 negative)
    :return: {attribute value: smoothed probability}
    '''
    denominator = class_count + len(unique_count_dict)
    return {value: float(count + 1) / denominator
            for value, count in unique_count_dict.items()}


def p_continuous(x, mean, var):
    '''
    Evaluate the normal probability density N(mean, var) at x.

    Note this is a density, not a probability, so it may exceed 1.

    :param x: numeric attribute value
    :param mean: class mean of the attribute
    :param var: class variance of the attribute (must be > 0)
    :return: density value
    '''
    coefficient = 1.0 / (sqrt(2 * pi) * sqrt(var))
    exponent = -(x - mean) ** 2 / (2 * var)
    return coefficient * exp(exponent)


def mean_var(data_list, ddof=0):
    '''
    Compute the mean and variance of a list of numbers.

    :param data_list: numeric values of one attribute (must be non-empty)
    :param ddof: delta degrees of freedom; the variance is divided by
                 len(data_list) - ddof. The default 0 gives the
                 maximum-likelihood (population) variance; ddof=1 gives the
                 unbiased sample variance.
                 NOTE(review): the textbook's worked example appears to use
                 ddof=1 - verify before relying on its numbers.
    :return: (mean, variance)
    :raises ZeroDivisionError: if data_list is empty or len(data_list) == ddof
    '''
    n = len(data_list)
    mean = sum(data_list) / float(n)
    squared_deviations = sum((value - mean) ** 2 for value in data_list)
    var = squared_deviations / float(n - ddof)
    return mean, var


def _log_gaussian(x, mean, var):
    '''Return the natural log of the normal density N(mean, var) at x.'''
    # Computing the log-density directly avoids exp() underflowing to 0.0
    # for points far from the mean (which made log(density) raise). The
    # division is evaluated first so var == 0 still raises ZeroDivisionError.
    return -(x - mean) ** 2 / (2 * var) - 0.5 * log(2 * pi * var)


def predict(data, p_class, unique_count_discrete_positive_p,
            unique_count_discrete_negative_p, means_positive, vars_positive,
            means_negative, vars_negative):
    '''
    Classify one sample using the tables produced by LNB ("lookup table" method).

    Scores are sums of log-probabilities (log space prevents underflow).

    :param data: one sample; the first len(unique_count_discrete_positive_p)
                 entries are discrete attribute values, followed by
                 len(means_positive) numeric values (numbers or numeric strings)
    :param p_class: class priors; must contain the keys '是' and '否'
    :param unique_count_discrete_positive_p: per-attribute {value: P(value | 是)}
    :param unique_count_discrete_negative_p: per-attribute {value: P(value | 否)}
    :param means_positive: per-numeric-attribute mean for '是'
    :param vars_positive: per-numeric-attribute variance for '是'
    :param means_negative: per-numeric-attribute mean for '否'
    :param vars_negative: per-numeric-attribute variance for '否'
    :return: (positive log score, negative log score, '好瓜' if the positive
              score is >= the negative one, otherwise '坏瓜')
    :raises KeyError: if a discrete value has no entry in a probability table
    :raises ZeroDivisionError: if any variance is 0
    '''
    num_discrete = len(unique_count_discrete_positive_p)
    p_positive = log(p_class['是'])
    p_negative = log(p_class['否'])
    for i in range(num_discrete):
        p_positive += log(unique_count_discrete_positive_p[i][data[i]])
        p_negative += log(unique_count_discrete_negative_p[i][data[i]])
    for i in range(len(means_positive)):
        # Numeric values may arrive as strings when read straight from CSV.
        x = float(data[num_discrete + i])
        p_positive += _log_gaussian(x, means_positive[i], vars_positive[i])
        p_negative += _log_gaussian(x, means_negative[i], vars_negative[i])

    label = '好瓜' if p_positive >= p_negative else '坏瓜'
    return p_positive, p_negative, label


if __name__ == '__main__':
    import sys

    # Dataset path: the first command-line argument if given, otherwise the
    # author's original hard-coded location.
    if len(sys.argv) > 1:
        filename = sys.argv[1]
    else:
        filename = "C:\\Users\\14399\\Desktop\\西瓜3.0.csv"
    X, Y = read_data(filename)
    # Train the model parameters
    (p_class, unique_count_discrete_positive_p, unique_count_discrete_negative_p,
     means_positive, vars_positive, means_negative, vars_negative) = LNB(X, Y)
    # Test sample: six discrete attribute values followed by two numeric values
    data_test = ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.697, 0.460]
    # Prediction
    predict_result = predict(data_test, p_class, unique_count_discrete_positive_p,
                             unique_count_discrete_negative_p, means_positive,
                             vars_positive, means_negative, vars_negative)
    print('predict result:', predict_result)

结果:predict result: (-3.4445474170196295, -9.920464036770394, '好瓜')

西瓜数据集3.0:链接 https://pan.baidu.com/s/1RXTUG9gP1Jn9HKFCiEzOlA ,提取码:3h6n

参考:https://blog.csdn.net/VictoriaW/article/details/78260103 

你可能感兴趣的:(机器学习)