使用SVM进行手写数字识别

导入相应的库
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio
from sklearn import svm
数据的导入与处理
# Per-dataset loading parameters, one row per dataset:
#   [features .mat file, labels .mat file, features variable name inside
#    the .mat file, labels variable name inside the .mat file, feature
#    dimension]
# The feature dimension (784 / 256, presumably flattened 28x28 / 16x16
# images -- TODO confirm) is not used below.
param_info = [
    ["mnist_train.mat", "mnist_train_labels.mat", "mnist_train", "mnist_train_labels", 784],
    ["usps_train.mat", "usps_train_labels.mat", "usps_train", "usps_train_labels", 256],
]

# Dataset selector: 0 -> MNIST, anything else -> USPS.
index = 1
dataset_name = "MNIST" if index == 0 else "USPS"
# Print the header for the dataset actually selected (the USPS branch used
# to print the MNIST header).
print(dataset_name + "数据集准确度:")

# Load features and labels from data/<file>.mat and pull out the named
# variables.
dataset_params = param_info[index]
x_data = sio.loadmat("data/" + dataset_params[0])[dataset_params[2]]
y_data = sio.loadmat("data/" + dataset_params[1])[dataset_params[3]]

# Random 80/20 train/test split. A permutation draws indices WITHOUT
# replacement, so the training set holds round(0.8 * n) distinct samples and
# the test set is exactly the remaining ones. np.random.choice with its
# default replace=True would produce duplicate training rows and a smaller
# effective training set.
n_samples = len(x_data)
train_size = round(n_samples * 0.8)
shuffled_index = np.random.permutation(n_samples)
rand_train_index = shuffled_index[:train_size]
rand_test_index = shuffled_index[train_size:]

# Fancy indexing already returns copies, so no extra np.array() is needed.
x_train_data = x_data[rand_train_index]
y_train_data = y_data[rand_train_index]

x_test_data = x_data[rand_test_index]
y_test_data = y_data[rand_test_index]
调用sklearn库,创建使用RBF核的SVM分类器
# Support vector classifier with an RBF kernel. gamma controls the kernel
# width and C the penalty on misclassified training samples.
svc_model = svm.SVC(kernel='rbf', C=10, gamma=0.001)
利用训练集对模型进行训练,利用测试集对模型进行预测
# Learning-curve experiment: refit svc_model on a growing prefix of the
# training set (10, 11, ..., all samples). After each fit, record the
# accuracy / error on that prefix and the accuracy on the full test set.
test_acc = []
train_err = []
train_acc = []

# Prefix sizes run from 10 (or the whole set, if it is smaller) up to the
# full training set, each exactly once. Iterating shape[0] times with a
# prefix of i+10 would refit on the full set 10 times and append duplicate
# points at the end of every curve.
n_train = x_train_data.shape[0]
first_size = min(10, n_train)
sizes = range(first_size, n_train + 1) if n_train > 0 else range(0)

# The test labels never change, so flatten them once outside the loop.
y_test_flat = y_test_data.ravel()

# NOTE(review): every step refits a full SVC, so this is slow on large
# training sets. Also, SVC.fit rejects a prefix that contains only one class;
# with a shuffled split that is unlikely at 10 samples, but not impossible.
for i, size in enumerate(sizes):
    x_seen = x_train_data[:size, :]
    y_seen = y_train_data[:size, :].ravel()
    svc_model.fit(x_seen, y_seen)
    predict = svc_model.predict(x_seen)  # predictions on the samples just fitted
    n_wrong = int(np.count_nonzero(predict != y_seen))
    temp_train_acc = 1 - float(n_wrong) / size
    train_acc.append(temp_train_acc)
    temp_train_err = 1 - temp_train_acc
    train_err.append(temp_train_err)
    temp_test_acc = svc_model.score(x_test_data, y_test_flat)
    test_acc.append(temp_test_acc)
    if i % 100 == 0:
        print("训练数据的个数:", size, "======训练出错的个数:", n_wrong)
        print("第" + str(i) + "步:======train_err:" + str(temp_train_err), "========train_acc:" + str(temp_train_acc), "==========test_acc:", str(temp_test_acc))
利用matplotlib库将结果可视化
# Figure 1: test and training accuracy after each training step. Titles use
# dataset_name so they match the dataset that was actually loaded, instead
# of always saying MNIST.
plt.plot(test_acc, 'r-', label="test_accuracy")
plt.plot(train_acc, 'g--', label="train_accuracy")
plt.legend(loc="best")
plt.title("Test_Accuracy And Train_Accuracy(" + dataset_name + ")")
plt.xlabel("Generation")
plt.ylabel("Accuracy")
plt.show()

# Figure 2: training error rate (1 - train accuracy), labelled as "loss".
plt.plot(train_err, "r-", label="train_loss")
plt.legend(loc="best")
plt.title("Train_Loss(" + dataset_name + ")")
plt.xlabel("Generation")
plt.ylabel("Loss")
plt.show()

使用SVM进行手写识别_第1张图片

使用SVM进行手写识别_第2张图片

你可能感兴趣的:(Tensorflow)