import csv
from collections import Counter
from math import sqrt, pi, exp, log
def read_data(filename, encoding=None):
    '''
    Read watermelon dataset 3.0 from a CSV file.

    The first row is treated as a header and skipped. Column 0 is dropped,
    columns 1..9 become one row of X, and column 9 (the label) is also
    collected into Y. So each row of X ends with its label.
    (Column meaning is assumed from the indexing here -- confirm against the file.)

    :param filename: path of the CSV file
    :param encoding: text encoding passed to open(); None means the platform default
    :return: (X, Y) -- X is a list of string rows (label last), Y the label list
    '''
    X, Y = [], []
    # The csv module requires newline='' so quoted fields with newlines and
    # '\r\n' line endings are handled correctly.
    with open(filename, newline='', encoding=encoding) as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row; tolerate an empty file
        for line in reader:
            # Blank lines (or rows of empty cells) would otherwise raise IndexError.
            if not any(cell.strip() for cell in line):
                continue
            X.append(line[1:10])
            Y.append(line[9])
    return X, Y
def LNB(X, Y, n_discrete=6, n_continuous=2):
    '''
    Train a naive Bayes classifier with Laplacian correction ("lookup table" method).

    Class priors:   P(c) = (|D_c| + 1) / (|D| + N)
    Discrete attrs: P(x_i | c) = (|D_c,x_i| + 1) / (|D_c| + N_i), where N_i is the
                    number of distinct values attribute i takes in the WHOLE data
                    set. Values that never occur with class c therefore still get
                    a non-zero table entry.
    Continuous attrs: per-class mean and variance, as returned by mean_var.

    :param X: samples; each row holds n_discrete discrete values, then
              n_continuous numeric values (numbers or numeric strings), and the
              label as its LAST element ('是' marks the positive class)
    :param Y: labels aligned with X; '是' and '否' are used as class keys
    :param n_discrete: number of leading discrete attributes (default 6)
    :param n_continuous: number of continuous attributes after them (default 2)
    :return: (p_class, unique_count_discrete_positive_p,
              unique_count_discrete_negative_p, means_positive, vars_positive,
              means_negative, vars_negative)
    :raises KeyError: if Y does not contain both '是' and '否'
    '''
    # Class priors with Laplacian correction
    total_num = len(Y)
    unique_class_dict = count_list(Y)
    num_class = len(unique_class_dict)
    p_class = {c: float(n + 1) / (total_num + num_class)
               for c, n in unique_class_dict.items()}
    num_positive = unique_class_dict['是']
    num_negative = unique_class_dict['否']

    # Split rows by the label stored in their last column
    positive_rows = [row for row in X if row[-1] == '是']
    negative_rows = [row for row in X if row[-1] != '是']

    # Discrete attribute probability tables
    unique_count_discrete_positive_p, unique_count_discrete_negative_p = [], []
    for i in range(n_discrete):
        all_values = {row[i] for row in X}
        positive_counts = count_list([row[i] for row in positive_rows])
        negative_counts = count_list([row[i] for row in negative_rows])
        # Add zero counts for values unseen in a class. Then len(dict) inside
        # p_discrete equals N_i and unseen values get (0 + 1) / (|D_c| + N_i).
        unique_count_discrete_positive_p.append(p_discrete(
            {v: positive_counts.get(v, 0) for v in all_values}, num_positive))
        unique_count_discrete_negative_p.append(p_discrete(
            {v: negative_counts.get(v, 0) for v in all_values}, num_negative))

    # Continuous attributes: per-class mean and variance
    means_positive, vars_positive = [], []
    means_negative, vars_negative = [], []
    for i in range(n_discrete, n_discrete + n_continuous):
        mean_positive, var_positive = mean_var(
            [float(row[i]) for row in positive_rows])
        mean_negative, var_negative = mean_var(
            [float(row[i]) for row in negative_rows])
        means_positive.append(mean_positive)
        vars_positive.append(var_positive)
        means_negative.append(mean_negative)
        vars_negative.append(var_negative)

    return (p_class, unique_count_discrete_positive_p,
            unique_count_discrete_negative_p, means_positive, vars_positive,
            means_negative, vars_negative)
def count_list(data_list):
    '''
    Count how many times each distinct value of one attribute occurs.

    :param data_list: iterable of hashable values (one attribute's data)
    :return: plain dict mapping each distinct value to its occurrence count
    '''
    # Counter counts in a single pass; the former set() + list.count() loop
    # rescanned the whole list once per distinct value (O(n*k)).
    return dict(Counter(data_list))
def p_discrete(unique_count_dict, class_count):
    '''
    Turn value counts of one discrete attribute into Laplace-smoothed probabilities.

    Each value with count n maps to (n + 1) / (class_count + K), where K is the
    number of entries in unique_count_dict.

    :param unique_count_dict: dict of attribute value -> occurrence count
    :param class_count: number of samples in the class, e.g. 8 positive, 9 negative
    :return: dict of attribute value -> smoothed probability
    '''
    denominator = class_count + len(unique_count_dict)
    return {value: float(count + 1) / denominator
            for value, count in unique_count_dict.items()}
def p_continuous(x, mean, var):
    '''
    Evaluate the normal probability density N(x; mean, var).

    :param x: attribute value
    :param mean: class-conditional mean
    :param var: class-conditional variance (must be non-zero)
    :return: density value at x
    '''
    inverse_norm = 1.0 / (sqrt(2 * pi) * sqrt(var))
    exponent = -(x - mean) ** 2 / (2 * var)
    return inverse_norm * exp(exponent)
def mean_var(data_list):
    '''
    Compute the mean and the population variance (divide by n) of one attribute.

    :param data_list: non-empty list of numbers for a single attribute
    :return: (mean, variance)
    '''
    size = float(len(data_list))
    mean = sum(data_list) / size
    squared_deviation = sum((value - mean) ** 2 for value in data_list)
    return mean, squared_deviation / size
def _log_gaussian(x, mean, var):
    '''Natural log of the normal density N(x; mean, var), computed without exp().'''
    # log(exp(...)) would underflow to log(0.0) and raise ValueError when x is
    # far from mean; the closed form stays finite.
    return -0.5 * log(2 * pi * var) - (x - mean) ** 2 / (2 * var)


def predict(data, p_class, unique_count_discrete_positive_p,
            unique_count_discrete_negative_p, means_positive, vars_positive,
            means_negative, vars_negative):
    '''
    Classify one sample using the trained lookup tables ("查表" method).

    Scores are summed as log-probabilities to avoid floating-point underflow.

    :param data: the discrete attribute values first, followed by the
                 continuous values (numbers or numeric strings)
    :param p_class: class priors keyed by '是' and '否'
    :param unique_count_discrete_positive_p: one dict per discrete attribute,
        mapping value -> P(value | 是)
    :param unique_count_discrete_negative_p: same, for class '否'
    :param means_positive: per continuous attribute mean for class '是'
    :param vars_positive: per continuous attribute variance for class '是'
    :param means_negative: per continuous attribute mean for class '否'
    :param vars_negative: per continuous attribute variance for class '否'
    :return: (log score for '是', log score for '否', '好瓜' or '坏瓜');
             ties are classified as '好瓜'
    :raises KeyError: if a discrete value has no entry in the tables
    '''
    # Attribute counts come from the tables instead of being hard-coded.
    n_discrete = len(unique_count_discrete_positive_p)
    p_positive = log(p_class['是'])
    p_negative = log(p_class['否'])
    for i in range(n_discrete):
        p_positive += log(unique_count_discrete_positive_p[i][data[i]])
        p_negative += log(unique_count_discrete_negative_p[i][data[i]])
    for i in range(len(means_positive)):
        # float() accepts both numbers and the strings produced by csv parsing.
        x = float(data[n_discrete + i])
        p_positive += _log_gaussian(x, means_positive[i], vars_positive[i])
        p_negative += _log_gaussian(x, means_negative[i], vars_negative[i])
    if p_positive >= p_negative:
        return p_positive, p_negative, '好瓜'
    return p_positive, p_negative, '坏瓜'
if __name__ == '__main__':
    import sys

    # Dataset path: the first command-line argument if given, otherwise the
    # original author's hard-coded location.
    if len(sys.argv) > 1:
        filename = sys.argv[1]
    else:
        filename = "C:\\Users\\14399\\Desktop\\西瓜3.0.csv"
    X, Y = read_data(filename)
    # Train the model parameters
    (p_class, unique_count_discrete_positive_p, unique_count_discrete_negative_p,
     means_positive, vars_positive, means_negative, vars_negative) = LNB(X, Y)
    # Test sample: six discrete attributes followed by density and sugar content
    data_test = ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.697, 0.460]
    # Prediction result
    predict_result = predict(data_test, p_class, unique_count_discrete_positive_p,
                             unique_count_discrete_negative_p, means_positive,
                             vars_positive, means_negative, vars_negative)
    print('predict result:', predict_result)
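# Result recorded by the original author (结果): predict result: (-3.4445474170196295, -9.920464036770394, '好瓜')
# Watermelon dataset 3.0 (西瓜3.0数据集), link (链接): https://pan.baidu.com/s/1RXTUG9gP1Jn9HKFCiEzOlA  password (密码): 3h6n
# Reference (参考): https://blog.csdn.net/VictoriaW/article/details/78260103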