#!/usr/bin/env python3
# -*- coding: utf-8 -*-
##########################################
# Bayes: naive Bayes classifier for handwritten digits
#        (贝叶斯公式: 用来描述两个条件概率之间的关系)
# Input:  32x32 text images, one digit character per pixel, flattened
#         into 1024-element feature vectors
# Labels: digit class taken from the file-name prefix "<label>_<n>.txt"
# 公式:   P(A|B) = P(B|A) P(A) / P(B)
# Output: 出错率 (error rate over the test set)
#########################################
importnumpy as npyimportosimporttime#P(B|A)=P(A|B)*P(A)/P(B)
# Root directory of the digits data set; the script reads its
# "trainingDigits" and "testDigits" sub-directories. The default is a
# machine-specific path, so it can be overridden with the
# DIGITS_DATASET_DIR environment variable. os.path.join(..., "") guarantees a
# trailing separator, which the plain string concatenation below relies on.
dataSetDir = os.path.join(os.environ.get("DIGITS_DATASET_DIR", "E:/digits/"), "")
class Bayes:
    """Naive Bayes classifier for discrete (categorical) feature vectors.

    ``fit`` records each label's prior ``P(label)`` and the training vectors
    seen with it. ``btest`` scores every candidate label as
    ``log P(label) + sum_i log P(x_i | label)``, where ``P(x_i | label)`` is
    the fraction of that label's training vectors whose i-th feature equals
    ``x_i`` exactly. It returns the best-scoring label.

    NOTE: no smoothing is applied. A feature value never seen with a label
    gives that label a score of -inf.
    """

    def __init__(self):
        self.length = -1  # feature-vector length; -1 means "not trained yet"
        self.labelrate = dict()  # label -> prior P(label)
        self.vectorrate = dict()  # label -> list of training vectors with that label
        self._matrices = dict()  # label -> 2-D array of those vectors (btest cache)

    def fit(self, dataset: list, labels: list):
        """Train the classifier, replacing anything learned by earlier calls.

        Args:
            dataset: non-empty sequence of equal-length feature vectors.
            labels: hashable class label for each vector in ``dataset``.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: if ``dataset`` and ``labels`` differ in length, if
                ``dataset`` is empty, or if its vectors differ in length.
        """
        print("训练开始")
        if len(dataset) != len(labels):
            raise ValueError("输入测试数组和类别数组长度不一致")
        if len(dataset) == 0:
            raise ValueError("训练数据为空")
        length = len(dataset[0])  # number of features per training vector
        if any(len(vector) != length for vector in dataset):
            raise ValueError("训练数据特征向量长度不一致")

        # Rebuild all state from scratch. The old code appended to the
        # previous fit's vectors and kept stale priors for vanished labels.
        grouped = dict()
        for vector, label in zip(dataset, labels):
            grouped.setdefault(label, []).append(vector)
        total = len(labels)
        self.length = length
        self.vectorrate = grouped
        # Prior P(label): share of the training samples carrying that label.
        self.labelrate = {label: len(vectors) / total for label, vectors in grouped.items()}
        # Stack each label's vectors into one 2-D array once, here, instead
        # of on every btest() call.
        self._matrices = {label: npy.asarray(vectors) for label, vectors in grouped.items()}
        print("训练结束")
        return self

    def btest(self, testdata, labelset):
        """Return the label from ``labelset`` with the highest posterior score.

        Args:
            testdata: feature vector with exactly ``self.length`` features.
            labelset: candidate labels. Labels never seen in training get a
                zero prior. Ties go to the label listed first.

        Returns:
            The best-scoring label from ``labelset``.

        Raises:
            ValueError: if the model is untrained, if ``testdata`` has the
                wrong length, or if ``labelset`` is empty.
        """
        if self.length == -1:
            raise ValueError("未开始训练,先训练")
        if len(testdata) != self.length:
            raise ValueError("测试数据长度与训练数据不一致")
        sample = npy.asarray(testdata)

        # Score in log space. Multiplying ~1000 per-feature probabilities
        # underflows to 0.0, which made every label tie.
        scores = dict()
        for thislb in labelset:
            matrix = self._matrices.get(thislb)
            if matrix is None:
                scores[thislb] = -npy.inf  # unseen label: prior is 0
                continue
            # counts[i] = how many training vectors of this label match feature i
            counts = (matrix == sample).sum(axis=0)
            with npy.errstate(divide="ignore"):  # log(0) -> -inf is intended
                loglik = npy.log(counts / matrix.shape[0]).sum()
            scores[thislb] = float(npy.log(self.labelrate[thislb]) + loglik)
        if not scores:
            raise ValueError("类别集合为空")
        # max() keeps the first maximal key, i.e. ties go to the earliest label.
        return max(scores, key=scores.get)


# Load data
def datatoarray(fname, rows=32, cols=32):
    """Read a text image and flatten it into a list of ints.

    Each of the first ``rows`` lines contributes its first ``cols``
    characters, each converted with ``int``. The default 32x32 yields the
    1024 features that ``traindata`` stores per image.

    Args:
        fname: path of the text image file.
        rows: number of lines to read (default 32).
        cols: number of characters to read from each line (default 32).

    Returns:
        A list of ``rows * cols`` ints in row-major order.

    Raises:
        IndexError / ValueError: if the file is smaller than ``rows`` x
            ``cols`` or contains a character ``int`` cannot parse.
    """
    arr = []
    # `with` closes the file even on error. The previous version never
    # closed it, leaking one handle per image read.
    with open(fname) as fh:
        for _ in range(rows):
            line = fh.readline()
            arr.extend(int(line[col]) for col in range(cols))
    return arr


# Extract the label from a file name
def seplabel(fname):
    """Return the integer label encoded at the start of a file name.

    The label is the text before the first ``"_"`` within the part of the
    name before the first ``"."``; e.g. ``"8_90.txt"`` -> ``8``.

    Raises:
        ValueError: if that text is not an integer literal.
    """
    stem, _, _ = fname.partition(".")
    head, _, _ = stem.partition("_")
    return int(head)


# Build the training data
def traindata():
    """Load every training image under ``dataSetDir + "trainingDigits"``.

    Returns:
        A tuple ``(trainarr, labels)``:

        * ``trainarr`` is a float array of shape ``(number_of_files, 1024)``
          holding one flattened image per row.
        * ``labels`` lists the digit parsed from each file name, in row order.
    """
    train_dir = dataSetDir + "trainingDigits"  # training (not test) images
    file_names = os.listdir(train_dir)
    labels = []
    trainarr = npy.zeros((len(file_names), 1024))
    for row, file_name in enumerate(file_names):
        labels.append(seplabel(file_name))
        trainarr[row] = datatoarray(train_dir + "/" + file_name)
    return trainarr, labels


# Main flow: handwritten-digit recognition with the Bayes classifier
bys = Bayes()
start = time.perf_counter()  # monotonic, high-resolution clock for timing

# Step 1: train on the training set
train_data, labels = traindata()
train_data = list(train_data)  # one feature vector per training image
bys.fit(train_data, labels)

# Step 2: classify every image in the test set and report each result
labelsall = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]  # candidate digit classes
testfile = os.listdir(dataSetDir + "testDigits")
num = len(testfile)
x = 0  # number of misclassified test images
for thisfilename in testfile:
    thislabel = seplabel(thisfilename)
    thisdataarr = datatoarray(dataSetDir + "testDigits/" + thisfilename)
    label = bys.btest(thisdataarr, labelsall)
    print("测试数字是:" + str(thislabel) + "识别出来的数字是:" + str(label))
    if label != thislabel:
        x += 1
        print("识别出错")
print(x)
if num:
    print("出错率:" + str(x / num))
else:
    # Avoid ZeroDivisionError when the test directory is empty.
    print("测试集为空,无法计算出错率")

end = time.perf_counter()
running_time = end - start
print(f"程序运行总耗时: {running_time:.5f} sec")