解答:
下面展示完整代码:
import os
import struct
import numpy as np
import matplotlib.pyplot as plt
import time
def load_mnist(path, kind='train'):
    """Load one MNIST split stored as uncompressed IDX files under ``path``.

    Reads ``<kind>-labels.idx1-ubyte`` and ``<kind>-images.idx3-ubyte``
    (e.g. ``kind='train'`` or ``kind='t10k'``).

    Returns:
        images: float array of shape (n, rows * cols), pixel values scaled
            from 0..255 into [0, 1].
        labels: int array of shape (n, 10), the one-hot encoding of the
            digit labels.

    Raises:
        ValueError: if either file has the wrong IDX magic number (e.g. it is
            still gzip-compressed or is not the expected file), or the image
            and label counts disagree.
    """
    labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)

    with open(labels_path, 'rb') as lbpath:
        # Header: big-endian magic number and item count.
        magic, _ = struct.unpack('>II', lbpath.read(8))
        if magic != 2049:
            raise ValueError('%s: bad label-file magic number %d (expected 2049)'
                             % (labels_path, magic))
        labels = np.fromfile(lbpath, dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        # Header: big-endian magic number, image count, rows, cols.
        magic, _, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        if magic != 2051:
            raise ValueError('%s: bad image-file magic number %d (expected 2051)'
                             % (images_path, magic))
        # Flatten every image using the dimensions declared in the header
        # instead of a hard-coded 784.
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(-1, rows * cols)

    if len(images) != len(labels):
        raise ValueError('image count %d does not match label count %d'
                         % (len(images), len(labels)))

    # One-hot encode: row i of the identity matrix is the code for digit i.
    one_hot = np.eye(10, dtype=int)[labels]
    # Scale pixels into [0, 1].
    return images / 255, one_hot
def sigmoid(z, first_derivative=False):
    """Element-wise logistic function 1 / (1 + exp(-z)).

    When ``first_derivative`` is true, return s * (1 - s) with s = sigmoid(z),
    i.e. the derivative of the sigmoid evaluated at ``z``.
    """
    activated = 1.0 / (1.0 + np.exp(-z))
    return activated * (1.0 - activated) if first_derivative else activated
def verify_validity(x_data, y_data, N, weights=None, biases=None):
    """Evaluate the network on the first ``N`` samples of a dataset.

    Args:
        x_data: input matrix, one sample per row.
        y_data: one-hot target matrix aligned row-by-row with ``x_data``.
        N: number of leading samples to evaluate.
        weights: optional list of per-layer weight matrices. Defaults to the
            first ``k`` entries of the module-level ``w`` filled by ``train()``.
        biases: optional list of per-layer bias rows. Defaults to the first
            ``k`` entries of the module-level ``b``.

    Returns:
        (loss, accuracy): ``loss`` is the squared error summed over outputs and
        divided by 2 * (number of evaluated samples); ``accuracy`` is the
        fraction of samples whose argmax prediction equals the argmax target.
    """
    if weights is None:
        weights = w[:k]
    if biases is None:
        biases = b[:k]

    # Single forward pass (the original computed it twice with identical
    # results). Each bias has one row and broadcasts over all samples.
    activation = x_data[:N]
    for weight, bias in zip(weights, biases):
        activation = sigmoid(np.matmul(activation, weight) + bias)

    targets = y_data[:N]
    # Use the real number of rows so a dataset shorter than N is still
    # averaged correctly.
    n = len(targets)
    y_predict = np.argmax(activation, axis=1)
    y_actual = np.argmax(targets, axis=1)
    accuracy = np.sum(np.equal(y_predict, y_actual)) / n
    # Same loss as used during training: half the mean squared error.
    loss = np.square(targets - activation).sum() / (2 * n)
    return loss, accuracy
def visualize_result(save_path, accuracies, losses):
    """Show the accuracy curve and then the loss curve, one window each.

    ``accuracies`` and ``losses`` are 2-D arrays whose columns are
    [epoch, train value, test value], as returned by ``train()``.
    ``save_path`` is accepted for interface compatibility but is currently
    unused: the figures are only displayed, never written to disk.
    """
    panels = (
        (accuracies, "Accuracy", ('Accuracy_train', 'Accuracy_test')),
        (losses, "Loss", ('Loss_train', 'Loss_test')),
    )
    for series, y_name, (train_label, test_label) in panels:
        epochs = series[:, 0]
        plt.plot(epochs, series[:, 1], label=train_label)
        plt.plot(epochs, series[:, 2], label=test_label)
        plt.xlabel("Epoch")
        plt.ylabel(y_name)
        plt.legend(loc='best')
        plt.show()
def _forward(x_batch, weights, biases):
    """Forward pass; return per-layer pre-activations and sigmoid activations."""
    pre_activations, activations = [], []
    layer_out = x_batch
    for weight, bias in zip(weights, biases):
        # The bias has one row and broadcasts over every sample in the batch.
        pre = np.matmul(layer_out, weight) + bias
        layer_out = sigmoid(pre)
        pre_activations.append(pre)
        activations.append(layer_out)
    return pre_activations, activations


def train():
    """Train the fully connected sigmoid network with mini-batch gradient descent.

    Reads these module-level settings: ``k``, ``net_dim``, ``x_train``,
    ``y_train``, ``x_test``, ``y_test``, ``N_train``, ``N_test``,
    ``min_batch``, ``num_epochs`` and ``learning_rate``. Parameters are
    (re)initialised uniformly in [-1, 1) into the module-level lists ``w`` and
    ``b`` in place, because ``verify_validity`` reads those same lists.

    Loss is half the squared error averaged over the batch. Gradients are
    averaged over ``min_batch`` samples. After every epoch, the loss and
    accuracy on the training and test sets are evaluated and printed.

    Returns:
        (losses, accuracies): NumPy arrays with one row per epoch, laid out as
        [epoch, train value, test value].
    """
    # Re-initialise parameters in place so repeated calls start fresh instead
    # of appending extra, unused layers to the shared lists.
    w.clear()
    b.clear()
    for j in range(k):
        w.append(2 * np.random.random((net_dim[j], net_dim[j + 1])) - 1)
        b.append(2 * np.random.random((1, net_dim[j + 1])) - 1)

    history_losses = []
    history_accuracies = []
    # A trailing partial batch (N_train % min_batch samples) is skipped.
    # NOTE(review): batches are taken in the same order every epoch; shuffling
    # per epoch is the usual practice for SGD.
    n_batches = N_train // min_batch
    L = float('nan')  # last mini-batch loss; stays NaN if there are no batches
    for i in range(num_epochs):
        for num_bat in range(n_batches):
            lo, hi = num_bat * min_batch, (num_bat + 1) * min_batch
            x_batch = x_train[lo:hi]
            y_batch = y_train[lo:hi]

            # Forward pass.
            z_layers, h_layers = _forward(x_batch, w, b)
            # Loss function.
            L = np.square(y_batch - h_layers[-1]).sum() / (2 * min_batch)

            # Backward pass: deltas[j] is the error term of layer k-1-j
            # (deltas[0] belongs to the output layer).
            deltas = [-(y_batch - h_layers[-1])
                      * sigmoid(z_layers[-1], first_derivative=True)]
            for j in range(1, k):
                deltas.append(np.matmul(deltas[j - 1], w[k - j].T)
                              * sigmoid(z_layers[k - 1 - j], first_derivative=True))

            # Update weights and biases only after every delta has been
            # computed, so back-propagation used the pre-update weights.
            for j in range(k):
                layer_input = x_batch if j == 0 else h_layers[j - 1]
                d = deltas[k - 1 - j]
                w[j] += -learning_rate * np.matmul(layer_input.T, d) / min_batch
                b[j] += -learning_rate * d.sum(axis=0, keepdims=True) / min_batch

        # Record results for this epoch: training-set and test-set metrics.
        loss_train, accuracy_train = verify_validity(x_train, y_train, N_train)
        loss_test, accuracy_test = verify_validity(x_test, y_test, N_test)
        history_losses.append([i, loss_train, loss_test])
        history_accuracies.append([i, accuracy_train, accuracy_test])
        print('Epoch: %d Loss:%f Loss_train:%f Loss_test:%f Accuracy_train: %f Accuracy_test: %f' % (i, L, loss_train, loss_test, accuracy_train, accuracy_test))
    return np.array(history_losses), np.array(history_accuracies)
if __name__ == '__main__':
    # Read the data. Point folder_path at the directory that contains the
    # unpacked "mnist" folder with the *-ubyte files.
    folder_path = 'D:\\Code\\Nju_study\\'
    x_train, y_train = load_mnist(folder_path + 'mnist', 'train')  # standard MNIST: (60000, 784), (60000, 10)
    x_test, y_test = load_mnist(folder_path + 'mnist', 't10k')  # standard MNIST: (10000, 784), (10000, 10)

    # Hyper-parameters.
    net_dim = [784, 128, 10]  # layer widths: input, hidden, output
    num_epochs = 50
    learning_rate = 1
    # Derive sample counts from the loaded arrays rather than hard-coding
    # 60000 / 10000, so a subset or different split cannot silently mismatch.
    N_train = len(x_train)
    min_batch = 100  # mini-batch size
    N_test = len(x_test)
    k = len(net_dim) - 1  # number of weight layers

    # Module-level state read and written by train() / verify_validity().
    losses = []
    accuracies = []
    w = []
    b = []
    z = []
    h = []
    delta = []

    # Train.
    print("==============================================================================================================================")
    start_time = time.time()
    losses, accuracies = train()
    run_time = time.time() - start_time  # elapsed training time in seconds
    hours, remainder = divmod(run_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print("本次运行时间:%d h %d m %d s" % (hours, minutes, seconds))

    # Plot the results.
    visualize_result(folder_path + 'learning_rate', accuracies, losses)
下面展示结果:
==============================================================================================================================
Epoch: 0 Loss:0.208053 Loss_train:0.187791 Loss_test:0.182892 Accuracy_train: 0.739783 Accuracy_test: 0.748900
Epoch: 1 Loss:0.144441 Loss_train:0.149888 Loss_test:0.145976 Accuracy_train: 0.789767 Accuracy_test: 0.799400
Epoch: 2 Loss:0.119358 Loss_train:0.134842 Loss_test:0.131591 Accuracy_train: 0.807383 Accuracy_test: 0.815000
Epoch: 3 Loss:0.104755 Loss_train:0.125744 Loss_test:0.123176 Accuracy_train: 0.817583 Accuracy_test: 0.822900
Epoch: 4 Loss:0.094766 Loss_train:0.119273 Loss_test:0.117325 Accuracy_train: 0.825100 Accuracy_test: 0.829200
Epoch: 5 Loss:0.086082 Loss_train:0.105821 Loss_test:0.105401 Accuracy_train: 0.885467 Accuracy_test: 0.885300
Epoch: 6 Loss:0.050557 Loss_train:0.067779 Loss_test:0.067928 Accuracy_train: 0.925033 Accuracy_test: 0.924100
Epoch: 7 Loss:0.044383 Loss_train:0.062873 Loss_test:0.063666 Accuracy_train: 0.930767 Accuracy_test: 0.928500
Epoch: 8 Loss:0.040248 Loss_train:0.059338 Loss_test:0.060657 Accuracy_train: 0.934583 Accuracy_test: 0.932600
Epoch: 9 Loss:0.037044 Loss_train:0.056478 Loss_test:0.058259 Accuracy_train: 0.938233 Accuracy_test: 0.935400
Epoch: 10 Loss:0.034462 Loss_train:0.054045 Loss_test:0.056255 Accuracy_train: 0.941000 Accuracy_test: 0.937800
Epoch: 11 Loss:0.032333 Loss_train:0.051919 Loss_test:0.054534 Accuracy_train: 0.943317 Accuracy_test: 0.939400
Epoch: 12 Loss:0.030551 Loss_train:0.050028 Loss_test:0.053031 Accuracy_train: 0.945450 Accuracy_test: 0.941500
Epoch: 13 Loss:0.029036 Loss_train:0.048329 Loss_test:0.051698 Accuracy_train: 0.947417 Accuracy_test: 0.942600
Epoch: 14 Loss:0.027730 Loss_train:0.046789 Loss_test:0.050504 Accuracy_train: 0.949033 Accuracy_test: 0.943700
Epoch: 15 Loss:0.026586 Loss_train:0.045385 Loss_test:0.049423 Accuracy_train: 0.950333 Accuracy_test: 0.945100
Epoch: 16 Loss:0.025569 Loss_train:0.044096 Loss_test:0.048438 Accuracy_train: 0.951733 Accuracy_test: 0.945600
Epoch: 17 Loss:0.024654 Loss_train:0.042906 Loss_test:0.047534 Accuracy_train: 0.953167 Accuracy_test: 0.946400
Epoch: 18 Loss:0.023821 Loss_train:0.041800 Loss_test:0.046700 Accuracy_train: 0.954550 Accuracy_test: 0.948400
Epoch: 19 Loss:0.023059 Loss_train:0.040768 Loss_test:0.045927 Accuracy_train: 0.955767 Accuracy_test: 0.949500
Epoch: 20 Loss:0.022358 Loss_train:0.039799 Loss_test:0.045207 Accuracy_train: 0.956983 Accuracy_test: 0.950000
Epoch: 21 Loss:0.021711 Loss_train:0.038887 Loss_test:0.044533 Accuracy_train: 0.957733 Accuracy_test: 0.950800
Epoch: 22 Loss:0.021112 Loss_train:0.038026 Loss_test:0.043902 Accuracy_train: 0.958750 Accuracy_test: 0.951700
Epoch: 23 Loss:0.020555 Loss_train:0.037211 Loss_test:0.043308 Accuracy_train: 0.959750 Accuracy_test: 0.951900
Epoch: 24 Loss:0.020036 Loss_train:0.036437 Loss_test:0.042749 Accuracy_train: 0.960633 Accuracy_test: 0.952300
Epoch: 25 Loss:0.019553 Loss_train:0.035702 Loss_test:0.042221 Accuracy_train: 0.961267 Accuracy_test: 0.952800
Epoch: 26 Loss:0.019101 Loss_train:0.035001 Loss_test:0.041721 Accuracy_train: 0.962167 Accuracy_test: 0.953600
Epoch: 27 Loss:0.018678 Loss_train:0.034333 Loss_test:0.041249 Accuracy_train: 0.962850 Accuracy_test: 0.954500
Epoch: 28 Loss:0.018283 Loss_train:0.033695 Loss_test:0.040801 Accuracy_train: 0.963783 Accuracy_test: 0.955100
Epoch: 29 Loss:0.017914 Loss_train:0.033085 Loss_test:0.040377 Accuracy_train: 0.964517 Accuracy_test: 0.955400
Epoch: 30 Loss:0.017570 Loss_train:0.032500 Loss_test:0.039975 Accuracy_train: 0.965250 Accuracy_test: 0.955900
Epoch: 31 Loss:0.017248 Loss_train:0.031940 Loss_test:0.039594 Accuracy_train: 0.965850 Accuracy_test: 0.956700
Epoch: 32 Loss:0.016949 Loss_train:0.031402 Loss_test:0.039231 Accuracy_train: 0.966417 Accuracy_test: 0.957100
Epoch: 33 Loss:0.016670 Loss_train:0.030886 Loss_test:0.038886 Accuracy_train: 0.966950 Accuracy_test: 0.957600
Epoch: 34 Loss:0.016410 Loss_train:0.030389 Loss_test:0.038559 Accuracy_train: 0.967633 Accuracy_test: 0.958300
Epoch: 35 Loss:0.016168 Loss_train:0.029911 Loss_test:0.038246 Accuracy_train: 0.968100 Accuracy_test: 0.958600
Epoch: 36 Loss:0.015943 Loss_train:0.029449 Loss_test:0.037949 Accuracy_train: 0.968783 Accuracy_test: 0.959100
Epoch: 37 Loss:0.015733 Loss_train:0.029004 Loss_test:0.037665 Accuracy_train: 0.969267 Accuracy_test: 0.959300
Epoch: 38 Loss:0.015537 Loss_train:0.028574 Loss_test:0.037394 Accuracy_train: 0.969783 Accuracy_test: 0.959600
Epoch: 39 Loss:0.015353 Loss_train:0.028158 Loss_test:0.037135 Accuracy_train: 0.970267 Accuracy_test: 0.959800
Epoch: 40 Loss:0.015181 Loss_train:0.027756 Loss_test:0.036888 Accuracy_train: 0.970617 Accuracy_test: 0.960100
Epoch: 41 Loss:0.015018 Loss_train:0.027367 Loss_test:0.036651 Accuracy_train: 0.970967 Accuracy_test: 0.960100
Epoch: 42 Loss:0.014863 Loss_train:0.026989 Loss_test:0.036424 Accuracy_train: 0.971450 Accuracy_test: 0.960200
Epoch: 43 Loss:0.014717 Loss_train:0.026623 Loss_test:0.036206 Accuracy_train: 0.972050 Accuracy_test: 0.960500
Epoch: 44 Loss:0.014577 Loss_train:0.026268 Loss_test:0.035997 Accuracy_train: 0.972433 Accuracy_test: 0.960700
Epoch: 45 Loss:0.014443 Loss_train:0.025923 Loss_test:0.035797 Accuracy_train: 0.972783 Accuracy_test: 0.961000
Epoch: 46 Loss:0.014315 Loss_train:0.025588 Loss_test:0.035604 Accuracy_train: 0.973050 Accuracy_test: 0.961100
Epoch: 47 Loss:0.014192 Loss_train:0.025262 Loss_test:0.035419 Accuracy_train: 0.973350 Accuracy_test: 0.961100
Epoch: 48 Loss:0.014074 Loss_train:0.024944 Loss_test:0.035241 Accuracy_train: 0.973633 Accuracy_test: 0.961100
Epoch: 49 Loss:0.013960 Loss_train:0.024635 Loss_test:0.035070 Accuracy_train: 0.974017 Accuracy_test: 0.961300
本次运行时间:0 h 2 m 20 s
注意事项:
1. 复制代码后只需把文件路径(folder_path)改成自己的路径即可运行(目录结构见下图),前提是程序能读到 MNIST 数据集。注意先解压 MNIST 压缩包,并注意文件后缀!!!程序读取的是解压后的 .idx1-ubyte / .idx3-ubyte 文件。
2. BP 算法的原理请自行查阅资料。用 BP 实现 MLP 时,别忘了把随机梯度下降改为(小)批量梯度下降:权重梯度矩阵相乘之后要除以批大小,取平均值。