理论
理论强推阮一峰大神的个人网站
1.贝叶斯推断及其互联网应用(一):定理简介
2.贝叶斯推断及其互联网应用(二):过滤垃圾邮件
非常简明易懂,然后我下面的代码就是实现上面过滤垃圾邮件算法的。
前期准备
数据来源
数据来源于《机器学习实战》中的第四章朴素贝叶斯分类器的实验数据。数据书上只提供了50条数据(25条正常邮件,25条垃圾邮件),感觉数据量偏小,以后打算使用scikit-learn提供的iris数据。
这里需要说明下,贝叶斯推断和朴素贝叶斯不是同一个概念
数据准备
和很多机器学习一样,数据需要拆分成训练集和测试集。
拆分训练集和测试集的思路如下:
1.遍历包含50条数据的email文件夹,获取文件列表
2.使用random.shuffle()函数打乱列表
3.截取乱序后的文件列表前10个文件路径,并转移到test文件夹下,作为测试集。
代码实现:
# -*- coding: utf-8 -*-
# @Date : 2017-05-09 13:06:56
# @Author : Alan Lau ([email protected])
# @Language : Python3.5
# from fwalker import fun
import random
# from reader import writetxt, readtxt
import shutil
import os
def fileWalker(path):
fileArray = []
for root, dirs, files in os.walk(path):
for fn in files:
eachpath = str(root+'\\'+fn)
fileArray.append(eachpath)
return fileArray
def main():
filepath = r'..\email'
testpath = r'..\test'
files = fileWalker(filepath)
random.shuffle(files)
top10 = files[:10]
for ech in top10:
ech_name = testpath+'\\'+('_'.join(ech.split('\\')[-2:]))
shutil.move(ech, testpath)
os.rename(testpath+'\\'+ech.split('\\')[-1], ech_name)
print('%s moved' % ech_name)
if __name__ == '__main__':
main()
最后获取的文件列表如下:
copy是备份数据,防止操作错误
ham文件列表:
spam文件列表:
test文件列表:
可见,数据准备后的测试集中,有7个垃圾邮件,3个正常的邮件。
代码实现
# -*- coding: utf-8 -*-
# @Date : 2017-05-09 09:29:13
# @Author : Alan Lau ([email protected])
# @Language : Python3.5
# from fwalker import fun
# from reader import readtxt
import os
def readtxt(path,encoding):
with open(path, 'r', encoding = encoding) as f:
lines = f.readlines()
return lines
def fileWalker(path):
fileArray = []
for root, dirs, files in os.walk(path):
for fn in files:
eachpath = str(root+'\\'+fn)
fileArray.append(eachpath)
return fileArray
def email_parser(email_path):
punctuations = """,.<>()*&^%$#@!'";~`[]{}|、\\/~+_-=?"""
content_list = readtxt(email_path, 'utf8')
content = (' '.join(content_list)).replace('\r\n', ' ').replace('\t', ' ')
clean_word = []
for punctuation in punctuations:
content = (' '.join(content.split(punctuation))).replace(' ', ' ')
clean_word = [word.lower()
for word in content.split(' ') if len(word) > 2]
return clean_word
def get_word(email_file):
word_list = []
word_set = []
email_paths = fileWalker(email_file)
for email_path in email_paths:
clean_word = email_parser(email_path)
word_list.append(clean_word)
word_set.extend(clean_word)
return word_list, set(word_set)
def count_word_prob(email_list, union_set):
word_prob = {}
for word in union_set:
counter = 0
for email in email_list:
if word in email:
counter += 1
else:
continue
prob = 0.0
if counter != 0:
prob = counter/len(email_list)
else:
prob = 0.01
word_prob[word] = prob
return word_prob
def filter(ham_word_pro, spam_word_pro, test_file):
test_paths = fileWalker(test_file)
for test_path in test_paths:
email_spam_prob = 0.0
spam_prob = 0.5
ham_prob = 0.5
file_name = test_path.split('\\')[-1]
prob_dict = {}
words = set(email_parser(test_path))
for word in words:
Psw = 0.0
if word not in spam_word_pro:
Psw = 0.4
else:
Pws = spam_word_pro[word]
Pwh = ham_word_pro[word]
Psw = spam_prob*(Pws/(Pwh*ham_prob+Pws*spam_prob))
prob_dict[word] = Psw
numerator = 1
denominator_h = 1
for k, v in prob_dict.items():
numerator *= v
denominator_h *= (1-v)
email_spam_prob = round(numerator/(numerator+denominator_h), 4)
if email_spam_prob > 0.5:
print(file_name, 'spam', email_spam_prob)
else:
print(file_name, 'ham', email_spam_prob)
# print(prob_dict)
# print('******************************************************')
# break
def main():
ham_file = r'..\email\ham'
spam_file = r'..\email\spam'
test_file = r'..\email\test'
ham_list, ham_set = get_word(ham_file)
spam_list, spam_set = get_word(spam_file)
union_set = ham_set | spam_set
ham_word_pro = count_word_prob(ham_list, union_set)
spam_word_pro = count_word_prob(spam_list, union_set)
filter(ham_word_pro, spam_word_pro, test_file)
if __name__ == '__main__':
main()
实验结果
ham_24.txt ham 0.0005
ham_3.txt ham 0.0
ham_4.txt ham 0.0
spam_11.txt spam 1.0
spam_14.txt spam 0.9999
spam_17.txt ham 0.0
spam_18.txt spam 0.9992
spam_19.txt spam 1.0
spam_22.txt spam 1.0
spam_8.txt spam 1.0
可见正确率为90%,实际上严格来说,应当将所有数据随机均分十组,每一组轮流作为一次测试集,剩下九组作为训练集,再将十次计算结果求均值,这个模型求出的分类效果才具有可靠性,其次,数据量小导致准确率较小的原因不排除在外。
所有代码以及数据GITHUB