Spam Classification in Python: Implementing Bayesian Inference for Spam Filtering

Theory

For the theory, I strongly recommend these two articles on Ruan Yifeng's personal website:

1. Bayesian Inference and Its Internet Applications (Part 1): An Introduction to the Theorem (贝叶斯推断及其互联网应用(一):定理简介)

2. Bayesian Inference and Its Internet Applications (Part 2): Filtering Spam (贝叶斯推断及其互联网应用(二):过滤垃圾邮件)

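Both articles are concise and easy to follow, and the code below implements the spam-filtering algorithm described in the second one.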
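For reference, these are the two formulas the classifier below computes, following the second article. $S$ stands for spam, $H$ for ham (normal email) and $W$ for a word. For a single word:

$$
P(S \mid W) = \frac{P(W \mid S)\,P(S)}{P(W \mid S)\,P(S) + P(W \mid H)\,P(H)}
$$

and for an email whose words have individual spam probabilities $P_1, \dots, P_n$ (each computed with the formula above):

$$
P = \frac{P_1 P_2 \cdots P_n}{P_1 P_2 \cdots P_n + (1 - P_1)(1 - P_2) \cdots (1 - P_n)}
$$

The code uses $P(S) = P(H) = 0.5$, estimates $P(W \mid S)$ and $P(W \mid H)$ as the fraction of spam or ham training emails that contain $W$, and combines the words as if they were independent.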

Preparation

Data source

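The data comes from the experiment in Chapter 4 (the naive Bayes classifier) of "Machine Learning in Action". The book provides only 50 emails (25 ham and 25 spam), which feels rather small, so I plan to use the iris data provided by scikit-learn later on.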

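Note that Bayesian inference and naive Bayes are not the same concept. Bayesian inference is the general practice of updating a probability with Bayes' theorem as evidence arrives; naive Bayes is a specific classifier that additionally assumes the features (here, words) are conditionally independent given the class.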

Data preparation

As in most machine learning tasks, the data has to be split into a training set and a test set.

The split works as follows:

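1. Walk the email folder containing the 50 emails and collect the list of files.

2. Shuffle the list with random.shuffle().

3. Take the first 10 file paths of the shuffled list and move those files into the test folder; they form the test set.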
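The three steps can also be sketched with `pathlib`, which avoids hard-coded Windows path separators. This is a hedged sketch, not the author's script: the `split_dataset` helper and its `seed` parameter are hypothetical, and it assumes the emails sit in `ham` and `spam` sub-folders of the email folder, as in this dataset. Like the original approach it moves the files, so keep a backup first.

```python
import random
import shutil
from pathlib import Path


def split_dataset(email_dir, test_dir, n_test=10, seed=None):
    """Move n_test random emails from email_dir/ham and email_dir/spam
    into test_dir, renamed <label>_<name> (e.g. spam_11.txt).

    Hypothetical helper; returns the list of new paths.
    """
    email_dir, test_dir = Path(email_dir), Path(test_dir)
    # Step 1: list the emails in the ham/ and spam/ sub-folders
    # (sorted so that a fixed seed always gives the same split)
    files = sorted(p for label in ('ham', 'spam')
                   for p in (email_dir / label).iterdir() if p.is_file())
    # Step 2: shuffle; pass a seed to make the split reproducible
    random.Random(seed).shuffle(files)
    # Step 3: move the first n_test files into test_dir
    test_dir.mkdir(parents=True, exist_ok=True)
    moved = []
    for src in files[:n_test]:
        dst = test_dir / (src.parent.name + '_' + src.name)
        shutil.move(str(src), str(dst))
        moved.append(dst)
    return moved
```

For example, `split_dataset('../email', '../email/test', seed=42)` would produce test files named like `spam_11.txt`, matching the names used in the results below.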

Code:

```python
# -*- coding: utf-8 -*-
# @Date : 2017-05-09 13:06:56
# @Author : Alan Lau
# @Language : Python3.5
import os
import random
import shutil


def fileWalker(path):
    """Return the paths of all files under `path`, recursively."""
    fileArray = []
    for root, dirs, files in os.walk(path):
        for fn in files:
            fileArray.append(os.path.join(root, fn))
    return fileArray


def main():
    filepath = os.path.join('..', 'email')
    # Same folder the classifier script reads its test emails from
    testpath = os.path.join('..', 'email', 'test')
    os.makedirs(testpath, exist_ok=True)
    # Sample only from ham/ and spam/, never from test/ or a backup folder
    files = (fileWalker(os.path.join(filepath, 'ham')) +
             fileWalker(os.path.join(filepath, 'spam')))
    random.shuffle(files)
    top10 = files[:10]
    for ech in top10:
        # e.g. ../email/spam/11.txt -> ../email/test/spam_11.txt
        label = os.path.basename(os.path.dirname(ech))
        ech_name = os.path.join(testpath, label + '_' + os.path.basename(ech))
        shutil.move(ech, ech_name)  # move and rename in one step
        print('%s moved' % ech_name)


if __name__ == '__main__':
    main()
```

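The resulting file lists are shown below. (`copy` is a backup of the data, kept in case an operation goes wrong.)

[Figures: ham file list, spam file list, test file list]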

As you can see, after preparation the test set contains 7 spam emails and 3 ham emails.

Implementation

```python
# -*- coding: utf-8 -*-
# @Date : 2017-05-09 09:29:13
# @Author : Alan Lau
# @Language : Python3.5
import os


def readtxt(path, encoding):
    # errors='ignore' skips bytes that are not valid in `encoding`
    # instead of raising UnicodeDecodeError
    with open(path, 'r', encoding=encoding, errors='ignore') as f:
        lines = f.readlines()
    return lines


def fileWalker(path):
    """Return the paths of all files under `path`, recursively."""
    fileArray = []
    for root, dirs, files in os.walk(path):
        for fn in files:
            fileArray.append(os.path.join(root, fn))
    return fileArray


def email_parser(email_path):
    """Split an email into lower-case words longer than 2 characters."""
    punctuations = """,.<>()*&^%$#@!'";~`[]{}|、\\/~+_-=?"""
    content_list = readtxt(email_path, 'utf8')
    # Text mode already turns '\r\n' into '\n', so replace '\n' itself
    content = ' '.join(content_list).replace('\n', ' ').replace('\t', ' ')
    for punctuation in punctuations:
        content = content.replace(punctuation, ' ')
    # split() with no argument also collapses runs of spaces
    clean_word = [word.lower() for word in content.split() if len(word) > 2]
    return clean_word


def get_word(email_file):
    """Return the word list of each email in a folder, plus the set of all words."""
    word_list = []
    word_set = set()
    for email_path in fileWalker(email_file):
        clean_word = email_parser(email_path)
        word_list.append(clean_word)
        word_set.update(clean_word)
    return word_list, word_set


def count_word_prob(email_list, union_set):
    """P(W|class): fraction of the class's emails containing each word.

    A word that never occurs in this class gets 0.01 instead of 0.
    """
    email_sets = [set(email) for email in email_list]  # fast membership tests
    word_prob = {}
    for word in union_set:
        counter = sum(1 for email in email_sets if word in email)
        if counter != 0:
            word_prob[word] = counter / len(email_list)
        else:
            word_prob[word] = 0.01
    return word_prob


def filter(ham_word_pro, spam_word_pro, test_file):
    """Classify every email under test_file and print its spam probability."""
    spam_prob = 0.5  # prior P(S)
    ham_prob = 0.5   # prior P(H)
    for test_path in fileWalker(test_file):
        file_name = os.path.basename(test_path)
        prob_dict = {}
        words = set(email_parser(test_path))
        for word in words:
            if word not in spam_word_pro:
                Psw = 0.4  # word never seen in training: assume P(S|W) = 0.4
            else:
                Pws = spam_word_pro[word]  # P(W|S)
                Pwh = ham_word_pro[word]   # P(W|H)
                # Bayes: P(S|W) = P(W|S)P(S) / (P(W|S)P(S) + P(W|H)P(H))
                Psw = spam_prob * Pws / (Pwh * ham_prob + Pws * spam_prob)
            prob_dict[word] = Psw
        # Combine: P = prod(P_i) / (prod(P_i) + prod(1 - P_i))
        numerator = 1
        denominator_h = 1
        for v in prob_dict.values():
            numerator *= v
            denominator_h *= (1 - v)
        email_spam_prob = round(numerator / (numerator + denominator_h), 4)
        if email_spam_prob > 0.5:
            print(file_name, 'spam', email_spam_prob)
        else:
            print(file_name, 'ham', email_spam_prob)


def main():
    ham_file = os.path.join('..', 'email', 'ham')
    spam_file = os.path.join('..', 'email', 'spam')
    test_file = os.path.join('..', 'email', 'test')
    ham_list, ham_set = get_word(ham_file)
    spam_list, spam_set = get_word(spam_file)
    union_set = ham_set | spam_set
    ham_word_pro = count_word_prob(ham_list, union_set)
    spam_word_pro = count_word_prob(spam_list, union_set)
    filter(ham_word_pro, spam_word_pro, test_file)


if __name__ == '__main__':
    main()
```
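Two pieces of `filter` can be pulled out and tested on their own. The sketch below uses hypothetical helper names (`word_spam_prob`, `combine_probs`). `word_spam_prob` is the same single-word Bayes formula; `combine_probs` gives the same combined probability, but swaps in a log-space calculation that is not part of the original code: it adds logarithms instead of multiplying raw probabilities, which is mathematically equivalent and avoids floating-point underflow on long emails.

```python
import math


def word_spam_prob(p_w_spam, p_w_ham, p_spam=0.5):
    """P(S|W) = P(W|S)P(S) / (P(W|S)P(S) + P(W|H)P(H)), with P(H) = 1 - P(S)."""
    p_ham = 1.0 - p_spam
    return p_w_spam * p_spam / (p_w_spam * p_spam + p_w_ham * p_ham)


def combine_probs(probs):
    """prod(p) / (prod(p) + prod(1 - p)), computed in log space.

    Every p must lie strictly between 0 and 1.
    """
    # z = log(prod p) - log(prod (1 - p)); the result is sigmoid(z)
    z = sum(math.log(p) - math.log(1.0 - p) for p in probs)
    # Numerically stable sigmoid: exp() only ever sees a non-positive argument
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)
```

For example, if an email had 200 distinct words at 0.01 and 200 at 0.99, both running products in `filter` would underflow to 0.0 and the division would raise `ZeroDivisionError`, whereas `combine_probs` returns about 0.5.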

Results

Each line gives the test file, the predicted class and the computed spam probability:

```text
ham_24.txt ham 0.0005
ham_3.txt ham 0.0
ham_4.txt ham 0.0
spam_11.txt spam 1.0
spam_14.txt spam 0.9999
spam_17.txt ham 0.0
spam_18.txt spam 0.9992
spam_19.txt spam 1.0
spam_22.txt spam 1.0
spam_8.txt spam 1.0
```
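Because each test file's name begins with its true label, the accuracy can be computed straight from this output. A minimal sketch; the `accuracy` helper is hypothetical, and the output lines above are pasted in as a string:

```python
# Classifier output, copied from the results above
RESULTS = """\
ham_24.txt ham 0.0005
ham_3.txt ham 0.0
ham_4.txt ham 0.0
spam_11.txt spam 1.0
spam_14.txt spam 0.9999
spam_17.txt ham 0.0
spam_18.txt spam 0.9992
spam_19.txt spam 1.0
spam_22.txt spam 1.0
spam_8.txt spam 1.0
"""


def accuracy(output):
    """Fraction of lines whose predicted label matches the file-name prefix."""
    correct = total = 0
    for line in output.splitlines():
        if not line.strip():
            continue  # ignore blank lines
        file_name, predicted, _prob = line.split()
        true_label = file_name.split('_')[0]  # 'spam_17.txt' -> 'spam'
        correct += true_label == predicted
        total += 1
    return correct / total


print(accuracy(RESULTS))  # 0.9
```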

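As shown, the accuracy is 90%: 9 of the 10 test emails are classified correctly, and only spam_17.txt is labelled ham. Strictly speaking, though, the result is only reliable if all the data is randomly divided into ten equal groups, each group is used once as the test set with the remaining nine as the training set, and the ten results are averaged (ten-fold cross-validation). Second, the small dataset may also be partly responsible for the modest accuracy.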
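The ten-fold procedure described above can be sketched generically. This is a sketch under assumptions, not code from the project: `k_fold_accuracy` and its `train`/`predict` callables are hypothetical. To use it here, `train` could build the two word-probability dictionaries with `count_word_prob`, and `predict` could wrap the scoring logic of `filter`, which would first have to return its label instead of printing it.

```python
import random


def k_fold_accuracy(samples, train, predict, k=10, seed=None):
    """Mean accuracy over k folds (k-fold cross-validation).

    samples: list of (features, label) pairs
    train:   train(list_of_pairs) -> model
    predict: predict(model, features) -> label
    """
    samples = list(samples)
    if len(samples) < k:
        raise ValueError('need at least k samples')
    random.Random(seed).shuffle(samples)
    # Deal the shuffled samples into k groups of (almost) equal size
    folds = [samples[i::k] for i in range(k)]
    scores = []
    for i, test in enumerate(folds):
        # Every fold except the i-th one is used for training
        train_set = [s for j, fold in enumerate(folds) if j != i for s in fold]
        model = train(train_set)
        correct = sum(predict(model, x) == y for x, y in test)
        scores.append(correct / len(test))
    return sum(scores) / k
```

With the 50 emails in this dataset, each fold holds 5 emails, so every email is tested exactly once and each model is trained on the other 45.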

All the code and data are on GitHub.
